mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-25 21:21:33 +02:00
rpcperms: add unique request ID
This commit adds a unique request ID that is the same for each gRPC request and response intercept message or each request/response message of a gRPC stream.
This commit is contained in:
@@ -545,6 +545,12 @@ func (h *middlewareHarness) interceptUnary(methodURI string,
|
|||||||
res := respIntercept.GetResponse()
|
res := respIntercept.GetResponse()
|
||||||
require.NotNil(h.t, res)
|
require.NotNil(h.t, res)
|
||||||
|
|
||||||
|
// We expect the request ID to be the same for the request intercept
|
||||||
|
// and the response intercept messages. But the message IDs must be
|
||||||
|
// different/unique.
|
||||||
|
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
|
||||||
|
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
|
||||||
|
|
||||||
// We need to accept the response as well.
|
// We need to accept the response as well.
|
||||||
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
||||||
|
|
||||||
@@ -593,6 +599,15 @@ func (h *middlewareHarness) interceptStream(methodURI string,
|
|||||||
res := respIntercept.GetResponse()
|
res := respIntercept.GetResponse()
|
||||||
require.NotNil(h.t, res)
|
require.NotNil(h.t, res)
|
||||||
|
|
||||||
|
// We expect the request ID to be the same for the auth intercept,
|
||||||
|
// request intercept and the response intercept messages. But the
|
||||||
|
// message IDs must be different/unique.
|
||||||
|
require.Equal(h.t, authIntercept.RequestId, respIntercept.RequestId)
|
||||||
|
require.Equal(h.t, reqIntercept.RequestId, respIntercept.RequestId)
|
||||||
|
require.NotEqual(h.t, authIntercept.MsgId, reqIntercept.MsgId)
|
||||||
|
require.NotEqual(h.t, authIntercept.MsgId, respIntercept.MsgId)
|
||||||
|
require.NotEqual(h.t, reqIntercept.MsgId, respIntercept.MsgId)
|
||||||
|
|
||||||
// We need to accept the response as well.
|
// We need to accept the response as well.
|
||||||
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
h.sendAccept(respIntercept.MsgId, responseReplacement)
|
||||||
|
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/btcsuite/btclog"
|
"github.com/btcsuite/btclog"
|
||||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
@@ -134,6 +135,12 @@ var (
|
|||||||
// | edited gRPC request to client
|
// | edited gRPC request to client
|
||||||
// v
|
// v
|
||||||
type InterceptorChain struct {
|
type InterceptorChain struct {
|
||||||
|
// lastRequestID is the ID of the last gRPC request or stream that was
|
||||||
|
// intercepted by the middleware interceptor.
|
||||||
|
//
|
||||||
|
// NOTE: Must be used atomically!
|
||||||
|
lastRequestID uint64
|
||||||
|
|
||||||
// Required by the grpc-gateway/v2 library for forward compatibility.
|
// Required by the grpc-gateway/v2 library for forward compatibility.
|
||||||
lnrpc.UnimplementedStateServer
|
lnrpc.UnimplementedStateServer
|
||||||
|
|
||||||
@@ -790,7 +797,8 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.acceptRequest(msg)
|
requestID := atomic.AddUint64(&r.lastRequestID, 1)
|
||||||
|
err = r.acceptRequest(requestID, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -800,7 +808,9 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
|||||||
return resp, respErr
|
return resp, respErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.interceptResponse(ctx, false, info.FullMethod, resp)
|
return r.interceptResponse(
|
||||||
|
ctx, requestID, false, info.FullMethod, resp,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -845,13 +855,15 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.acceptRequest(msg)
|
requestID := atomic.AddUint64(&r.lastRequestID, 1)
|
||||||
|
err = r.acceptRequest(requestID, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
wrappedSS := &serverStreamWrapper{
|
wrappedSS := &serverStreamWrapper{
|
||||||
ServerStream: ss,
|
ServerStream: ss,
|
||||||
|
requestID: requestID,
|
||||||
fullMethod: info.FullMethod,
|
fullMethod: info.FullMethod,
|
||||||
interceptor: r,
|
interceptor: r,
|
||||||
}
|
}
|
||||||
@@ -900,7 +912,9 @@ func (r *InterceptorChain) middlewareRegistered() bool {
|
|||||||
// registered for it. This means either a middleware has requested read-only
|
// registered for it. This means either a middleware has requested read-only
|
||||||
// access or the request actually has a macaroon which a caveat the middleware
|
// access or the request actually has a macaroon which a caveat the middleware
|
||||||
// registered for.
|
// registered for.
|
||||||
func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
func (r *InterceptorChain) acceptRequest(requestID uint64,
|
||||||
|
msg *InterceptionRequest) error {
|
||||||
|
|
||||||
r.RLock()
|
r.RLock()
|
||||||
defer r.RUnlock()
|
defer r.RUnlock()
|
||||||
|
|
||||||
@@ -915,7 +929,7 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := middleware.intercept(msg)
|
resp, err := middleware.intercept(requestID, msg)
|
||||||
|
|
||||||
// Error during interception itself.
|
// Error during interception itself.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -936,7 +950,8 @@ func (r *InterceptorChain) acceptRequest(msg *InterceptionRequest) error {
|
|||||||
// overwrite/replace the response, this needs to be handled differently than the
|
// overwrite/replace the response, this needs to be handled differently than the
|
||||||
// request/auth path above.
|
// request/auth path above.
|
||||||
func (r *InterceptorChain) interceptResponse(ctx context.Context,
|
func (r *InterceptorChain) interceptResponse(ctx context.Context,
|
||||||
isStream bool, fullMethod string, m interface{}) (interface{}, error) {
|
requestID uint64, isStream bool, fullMethod string,
|
||||||
|
m interface{}) (interface{}, error) {
|
||||||
|
|
||||||
r.RLock()
|
r.RLock()
|
||||||
defer r.RUnlock()
|
defer r.RUnlock()
|
||||||
@@ -960,7 +975,7 @@ func (r *InterceptorChain) interceptResponse(ctx context.Context,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := middleware.intercept(msg)
|
resp, err := middleware.intercept(requestID, msg)
|
||||||
|
|
||||||
// Error during interception itself.
|
// Error during interception itself.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -988,6 +1003,8 @@ type serverStreamWrapper struct {
|
|||||||
// ServerStream is the stream that's being wrapped.
|
// ServerStream is the stream that's being wrapped.
|
||||||
grpc.ServerStream
|
grpc.ServerStream
|
||||||
|
|
||||||
|
requestID uint64
|
||||||
|
|
||||||
fullMethod string
|
fullMethod string
|
||||||
|
|
||||||
interceptor *InterceptorChain
|
interceptor *InterceptorChain
|
||||||
@@ -997,7 +1014,7 @@ type serverStreamWrapper struct {
|
|||||||
// intercept streaming RPC responses.
|
// intercept streaming RPC responses.
|
||||||
func (w *serverStreamWrapper) SendMsg(m interface{}) error {
|
func (w *serverStreamWrapper) SendMsg(m interface{}) error {
|
||||||
newMsg, err := w.interceptor.interceptResponse(
|
newMsg, err := w.interceptor.interceptResponse(
|
||||||
w.ServerStream.Context(), true, w.fullMethod, m,
|
w.ServerStream.Context(), w.requestID, true, w.fullMethod, m,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1022,5 +1039,5 @@ func (w *serverStreamWrapper) RecvMsg(m interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.interceptor.acceptRequest(msg)
|
return w.interceptor.acceptRequest(w.requestID, msg)
|
||||||
}
|
}
|
||||||
|
@@ -107,12 +107,13 @@ func NewMiddlewareHandler(name, customCaveatName string, readOnly bool,
|
|||||||
// feedback on it and sending the feedback to the appropriate channel. All steps
|
// feedback on it and sending the feedback to the appropriate channel. All steps
|
||||||
// are guarded by the configured timeout to make sure a middleware cannot slow
|
// are guarded by the configured timeout to make sure a middleware cannot slow
|
||||||
// down requests too much.
|
// down requests too much.
|
||||||
func (h *MiddlewareHandler) intercept(
|
func (h *MiddlewareHandler) intercept(requestID uint64,
|
||||||
req *InterceptionRequest) (*interceptResponse, error) {
|
req *InterceptionRequest) (*interceptResponse, error) {
|
||||||
|
|
||||||
respChan := make(chan *interceptResponse, 1)
|
respChan := make(chan *interceptResponse, 1)
|
||||||
|
|
||||||
newRequest := &interceptRequest{
|
newRequest := &interceptRequest{
|
||||||
|
requestID: requestID,
|
||||||
request: req,
|
request: req,
|
||||||
response: respChan,
|
response: respChan,
|
||||||
}
|
}
|
||||||
@@ -233,7 +234,9 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
|
|||||||
req := newRequest.request
|
req := newRequest.request
|
||||||
interceptRequests[msgID] = newRequest
|
interceptRequests[msgID] = newRequest
|
||||||
|
|
||||||
interceptReq, err := req.ToRPC(msgID)
|
interceptReq, err := req.ToRPC(
|
||||||
|
newRequest.requestID, msgID,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -447,10 +450,11 @@ func macaroonFromContext(ctx context.Context) (*macaroon.Macaroon, []byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ToRPC converts the interception request to its RPC counterpart.
|
// ToRPC converts the interception request to its RPC counterpart.
|
||||||
func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
|
func (r *InterceptionRequest) ToRPC(requestID,
|
||||||
error) {
|
msgID uint64) (*lnrpc.RPCMiddlewareRequest, error) {
|
||||||
|
|
||||||
rpcRequest := &lnrpc.RPCMiddlewareRequest{
|
rpcRequest := &lnrpc.RPCMiddlewareRequest{
|
||||||
|
RequestId: requestID,
|
||||||
MsgId: msgID,
|
MsgId: msgID,
|
||||||
RawMacaroon: r.RawMacaroon,
|
RawMacaroon: r.RawMacaroon,
|
||||||
CustomCaveatCondition: r.CustomCaveatCondition,
|
CustomCaveatCondition: r.CustomCaveatCondition,
|
||||||
@@ -495,6 +499,7 @@ func (r *InterceptionRequest) ToRPC(msgID uint64) (*lnrpc.RPCMiddlewareRequest,
|
|||||||
// out to a middleware and the response that is eventually sent back by the
|
// out to a middleware and the response that is eventually sent back by the
|
||||||
// middleware.
|
// middleware.
|
||||||
type interceptRequest struct {
|
type interceptRequest struct {
|
||||||
|
requestID uint64
|
||||||
request *InterceptionRequest
|
request *InterceptionRequest
|
||||||
response chan *interceptResponse
|
response chan *interceptResponse
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user