From b1d8767a0ca0fcc662854c2850c93ea99b463925 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Wed, 6 Jul 2022 21:17:00 +0200 Subject: [PATCH] rpcperms: intercept errors too --- rpcperms/interceptor.go | 43 ++++++++++++-- rpcperms/middleware_handler.go | 104 +++++++++++++++++++++++---------- 2 files changed, 110 insertions(+), 37 deletions(-) diff --git a/rpcperms/interceptor.go b/rpcperms/interceptor.go index c5ebc24bb..07ee4fccb 100644 --- a/rpcperms/interceptor.go +++ b/rpcperms/interceptor.go @@ -817,14 +817,27 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn return nil, err } - resp, respErr := handler(ctx, req) - if respErr != nil { - return resp, respErr + // Call the handler, which executes the request against lnd. + lndResp, lndErr := handler(ctx, req) + if lndErr != nil { + // The call to lnd ended in an error and not a normal + // proto message response. Send the error to the + // interceptor as well to inform about the abnormal + // termination of the stream and to give the option to + // replace the error message with a custom one. + replacedErr, err := r.interceptMessage( + ctx, TypeResponse, requestID, false, + info.FullMethod, lndErr, + ) + if err != nil { + return nil, err + } + return lndResp, replacedErr.(error) } return r.interceptMessage( ctx, TypeResponse, requestID, false, info.FullMethod, - resp, + lndResp, ) } } @@ -883,7 +896,27 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer interceptor: r, } - return handler(srv, wrappedSS) + // Call the stream handler, which will block as long as the + // stream is alive. + lndErr := handler(srv, wrappedSS) + if lndErr != nil { + // This is an error being returned from lnd. Send it to + // the interceptor as well to inform about the abnormal + // termination of the stream and to give the option to + // replace the error message with a custom one. + replacedErr, err := r.interceptMessage( + ss.Context(), TypeResponse, requestID, + true, info.FullMethod, lndErr, + ) + if err != nil { + return err + } + + return replacedErr.(error) + } + + // Normal/successful termination of the stream. + return nil } } diff --git a/rpcperms/middleware_handler.go b/rpcperms/middleware_handler.go index 7ccad0b58..1e9d16b56 100644 --- a/rpcperms/middleware_handler.go +++ b/rpcperms/middleware_handler.go @@ -263,24 +263,41 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error, break } - // For intercepted messages we also allow the - // content itself to be overwritten. - if t.ReplaceResponse { - response.replace = true - protoMsg, err := parseProto( - requestInfo.request.ProtoTypeName, - t.ReplacementSerialized, + // If there's nothing to replace, we're done, + // this request was just accepted. + if !t.ReplaceResponse { + break + } + + // We are replacing the response, the question + // now just is: was it an error or a proper + // proto message? + response.replace = true + if requestInfo.request.IsError { + response.replacement = errors.New( + string(t.ReplacementSerialized), ) - if err != nil { - response.err = err - - break - } - - response.replacement = protoMsg + break } + // Not an error but a proper proto message that + // needs to be replaced. For that we need to + // parse it from the raw bytes into the full RPC + // message. + protoMsg, err := parseProto( + requestInfo.request.ProtoTypeName, + t.ReplacementSerialized, + ) + + if err != nil { + response.err = err + + break + } + + response.replacement = protoMsg + default: return fmt.Errorf("unknown middleware "+ "message: %v", msg) @@ -369,6 +386,10 @@ type InterceptionRequest struct { // ProtoTypeName is the fully qualified name of the protobuf type of the // request or response message that is serialized in the field above. ProtoTypeName string + + // IsError indicates that the message contained within this request is + // an error. Will only ever be true for response messages. + IsError bool } // NewMessageInterceptionRequest creates a new interception request for either @@ -382,24 +403,36 @@ func NewMessageInterceptionRequest(ctx context.Context, return nil, err } - rpcReq, ok := m.(proto.Message) - if !ok { - return nil, fmt.Errorf("msg is not proto message: %v", m) - } - rawRequest, err := proto.Marshal(rpcReq) - if err != nil { - return nil, fmt.Errorf("cannot marshal proto msg: %v", err) + req := &InterceptionRequest{ + Type: authType, + StreamRPC: isStream, + Macaroon: mac, + RawMacaroon: rawMacaroon, + FullURI: fullMethod, } - return &InterceptionRequest{ - Type: authType, - StreamRPC: isStream, - Macaroon: mac, - RawMacaroon: rawMacaroon, - FullURI: fullMethod, - ProtoSerialized: rawRequest, - ProtoTypeName: string(proto.MessageName(rpcReq)), - }, nil + // The message is either a proto message or an error, we don't support + // any other types being intercepted. + switch t := m.(type) { + case proto.Message: + req.ProtoSerialized, err = proto.Marshal(t) + if err != nil { + return nil, fmt.Errorf("cannot marshal proto msg: %v", + err) + } + req.ProtoTypeName = string(proto.MessageName(t)) + + case error: + req.ProtoSerialized = []byte(t.Error()) + req.ProtoTypeName = "error" + req.IsError = true + + default: + return nil, fmt.Errorf("unsupported type for interception "+ + "request: %v", m) + } + + return req, nil } // NewStreamAuthInterceptionRequest creates a new interception request for a @@ -484,6 +517,7 @@ func (r *InterceptionRequest) ToRPC(requestID, StreamRpc: r.StreamRPC, TypeName: r.ProtoTypeName, Serialized: r.ProtoSerialized, + IsError: r.IsError, }, } @@ -549,8 +583,14 @@ func replaceProtoMsg(target interface{}, replacement interface{}) error { return fmt.Errorf("replacement message is of wrong type") } - proto.Reset(targetMsg) - proto.Merge(targetMsg, replacementMsg) + replacementBytes, err := proto.Marshal(replacementMsg) + if err != nil { + return fmt.Errorf("error marshaling replacement: %v", err) + } + err = proto.Unmarshal(replacementBytes, targetMsg) + if err != nil { + return fmt.Errorf("error unmarshaling replacement: %v", err) + } return nil }