diff --git a/rpcperms/interceptor.go b/rpcperms/interceptor.go index 861d94b51..c5ebc24bb 100644 --- a/rpcperms/interceptor.go +++ b/rpcperms/interceptor.go @@ -118,7 +118,7 @@ var ( // | Macaroon Interceptor | // +----------------------------------+--------> +---------------------+ // | RPC Macaroon Middleware Handler |<-------- | External Middleware | -// +----------------------------------+ | - approve request | +// +----------------------------------+ | - modify request | // | Prometheus Interceptor | +---------------------+ // +-+--------------------------------+ // | validated gRPC request from client @@ -808,15 +808,11 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn return handler(ctx, req) } - msg, err := NewMessageInterceptionRequest( - ctx, TypeRequest, false, info.FullMethod, req, - ) - if err != nil { - return nil, err - } - requestID := atomic.AddUint64(&r.lastRequestID, 1) - err = r.acceptRequest(requestID, msg) + req, err := r.interceptMessage( + ctx, TypeRequest, requestID, false, info.FullMethod, + req, + ) if err != nil { return nil, err } @@ -826,8 +822,9 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn return resp, respErr } - return r.interceptResponse( - ctx, requestID, false, info.FullMethod, resp, + return r.interceptMessage( + ctx, TypeResponse, requestID, false, info.FullMethod, + resp, ) } } @@ -874,7 +871,7 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer } requestID := atomic.AddUint64(&r.lastRequestID, 1) - err = r.acceptRequest(requestID, msg) + err = r.acceptStream(requestID, msg) if err != nil { return err } @@ -926,11 +923,11 @@ func (r *InterceptorChain) middlewareRegistered() bool { return len(r.registeredMiddleware) > 0 } -// acceptRequest sends an intercept request to all middlewares that have +// acceptStream sends an intercept request to all middlewares that have // registered for it. This means either a middleware has requested read-only // access or the request actually has a macaroon with a caveat the middleware // registered for. -func (r *InterceptorChain) acceptRequest(requestID uint64, +func (r *InterceptorChain) acceptStream(requestID uint64, msg *InterceptionRequest) error { r.RLock() @@ -967,13 +964,13 @@ func (r *InterceptorChain) acceptRequest(requestID uint64, return nil } -// interceptResponse sends out an intercept request for an RPC response. Since +// interceptMessage sends out an intercept request for an RPC response. Since // middleware that hasn't registered for the read-only mode has the option to -// overwrite/replace the response, this needs to be handled differently than the -// request/auth path above. -func (r *InterceptorChain) interceptResponse(ctx context.Context, - requestID uint64, isStream bool, fullMethod string, - m interface{}) (interface{}, error) { +// overwrite/replace the message, this needs to be handled differently than the +// auth path above. +func (r *InterceptorChain) interceptMessage(ctx context.Context, + interceptType InterceptType, requestID uint64, isStream bool, + fullMethod string, m interface{}) (interface{}, error) { r.RLock() defer r.RUnlock() @@ -981,7 +978,8 @@ func (r *InterceptorChain) interceptResponse(ctx context.Context, currentMessage := m for _, middleware := range r.registeredMiddleware { msg, err := NewMessageInterceptionRequest( - ctx, TypeResponse, isStream, fullMethod, currentMessage, + ctx, interceptType, isStream, fullMethod, + currentMessage, ) if err != nil { return nil, err @@ -1039,8 +1037,9 @@ type serverStreamWrapper struct { // SendMsg is called when lnd sends a message to the client. This is wrapped to // intercept streaming RPC responses. func (w *serverStreamWrapper) SendMsg(m interface{}) error { - newMsg, err := w.interceptor.interceptResponse( - w.ServerStream.Context(), w.requestID, true, w.fullMethod, m, + newMsg, err := w.interceptor.interceptMessage( + w.ServerStream.Context(), TypeResponse, w.requestID, true, + w.fullMethod, m, ) if err != nil { return err @@ -1057,13 +1056,13 @@ func (w *serverStreamWrapper) RecvMsg(m interface{}) error { return err } - msg, err := NewMessageInterceptionRequest( - w.ServerStream.Context(), TypeRequest, true, w.fullMethod, - m, + req, err := w.interceptor.interceptMessage( + w.ServerStream.Context(), TypeRequest, w.requestID, true, + w.fullMethod, m, ) if err != nil { return err } - return w.interceptor.acceptRequest(w.requestID, msg) + return replaceProtoMsg(m, req) } diff --git a/rpcperms/middleware_handler.go b/rpcperms/middleware_handler.go index 0a890d2c3..7ccad0b58 100644 --- a/rpcperms/middleware_handler.go +++ b/rpcperms/middleware_handler.go @@ -263,11 +263,9 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, break } - // For intercepted responses we also allow the + // For intercepted messages we also allow the // content itself to be overwritten. - if requestInfo.request.Type == TypeResponse && - t.ReplaceResponse { - + if t.ReplaceResponse { response.replace = true protoMsg, err := parseProto( requestInfo.request.ProtoTypeName, @@ -322,7 +320,8 @@ const ( // TypeRequest is the type of intercept message that is sent when an RPC // request message is sent to lnd. For client-streaming RPCs a new // message of this type is sent for each individual RPC request sent to - // the stream. + // the stream. Middleware has the option to modify a request message + // before it is delivered to lnd. TypeRequest InterceptType = 2 // TypeResponse is the type of intercept message that is sent when an