rpcperms: intercept errors too

This commit is contained in:
Oliver Gugger
2022-07-06 21:17:00 +02:00
parent 502542da60
commit b1d8767a0c
2 changed files with 110 additions and 37 deletions

View File

@ -817,14 +817,27 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
return nil, err return nil, err
} }
resp, respErr := handler(ctx, req) // Call the handler, which executes the request against lnd.
if respErr != nil { lndResp, lndErr := handler(ctx, req)
return resp, respErr if lndErr != nil {
// The call to lnd ended in an error and not a normal
// proto message response. Send the error to the
// interceptor as well to inform about the abnormal
// termination of the stream and to give the option to
// replace the error message with a custom one.
replacedErr, err := r.interceptMessage(
ctx, TypeResponse, requestID, false,
info.FullMethod, lndErr,
)
if err != nil {
return nil, err
}
return lndResp, replacedErr.(error)
} }
return r.interceptMessage( return r.interceptMessage(
ctx, TypeResponse, requestID, false, info.FullMethod, ctx, TypeResponse, requestID, false, info.FullMethod,
resp, lndResp,
) )
} }
} }
@ -883,7 +896,27 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
interceptor: r, interceptor: r,
} }
return handler(srv, wrappedSS) // Call the stream handler, which will block as long as the
// stream is alive.
lndErr := handler(srv, wrappedSS)
if lndErr != nil {
// This is an error being returned from lnd. Send it to
// the interceptor as well to inform about the abnormal
// termination of the stream and to give the option to
// replace the error message with a custom one.
replacedErr, err := r.interceptMessage(
ss.Context(), TypeResponse, requestID,
true, info.FullMethod, lndErr,
)
if err != nil {
return err
}
return replacedErr.(error)
}
// Normal/successful termination of the stream.
return nil
} }
} }

View File

@ -263,24 +263,41 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
break break
} }
// For intercepted messages we also allow the // If there's nothing to replace, we're done,
// content itself to be overwritten. // this request was just accepted.
if t.ReplaceResponse { if !t.ReplaceResponse {
response.replace = true break
protoMsg, err := parseProto( }
requestInfo.request.ProtoTypeName,
t.ReplacementSerialized, // We are replacing the response, the question
// now just is: was it an error or a proper
// proto message?
response.replace = true
if requestInfo.request.IsError {
response.replacement = errors.New(
string(t.ReplacementSerialized),
) )
if err != nil { break
response.err = err
break
}
response.replacement = protoMsg
} }
// Not an error but a proper proto message that
// needs to be replaced. For that we need to
// parse it from the raw bytes into the full RPC
// message.
protoMsg, err := parseProto(
requestInfo.request.ProtoTypeName,
t.ReplacementSerialized,
)
if err != nil {
response.err = err
break
}
response.replacement = protoMsg
default: default:
return fmt.Errorf("unknown middleware "+ return fmt.Errorf("unknown middleware "+
"message: %v", msg) "message: %v", msg)
@ -369,6 +386,10 @@ type InterceptionRequest struct {
// ProtoTypeName is the fully qualified name of the protobuf type of the // ProtoTypeName is the fully qualified name of the protobuf type of the
// request or response message that is serialized in the field above. // request or response message that is serialized in the field above.
ProtoTypeName string ProtoTypeName string
// IsError indicates that the message contained within this request is
// an error. Will only ever be true for response messages.
IsError bool
} }
// NewMessageInterceptionRequest creates a new interception request for either // NewMessageInterceptionRequest creates a new interception request for either
@ -382,24 +403,36 @@ func NewMessageInterceptionRequest(ctx context.Context,
return nil, err return nil, err
} }
rpcReq, ok := m.(proto.Message) req := &InterceptionRequest{
if !ok { Type: authType,
return nil, fmt.Errorf("msg is not proto message: %v", m) StreamRPC: isStream,
} Macaroon: mac,
rawRequest, err := proto.Marshal(rpcReq) RawMacaroon: rawMacaroon,
if err != nil { FullURI: fullMethod,
return nil, fmt.Errorf("cannot marshal proto msg: %v", err)
} }
return &InterceptionRequest{ // The message is either a proto message or an error, we don't support
Type: authType, // any other types being intercepted.
StreamRPC: isStream, switch t := m.(type) {
Macaroon: mac, case proto.Message:
RawMacaroon: rawMacaroon, req.ProtoSerialized, err = proto.Marshal(t)
FullURI: fullMethod, if err != nil {
ProtoSerialized: rawRequest, return nil, fmt.Errorf("cannot marshal proto msg: %v",
ProtoTypeName: string(proto.MessageName(rpcReq)), err)
}, nil }
req.ProtoTypeName = string(proto.MessageName(t))
case error:
req.ProtoSerialized = []byte(t.Error())
req.ProtoTypeName = "error"
req.IsError = true
default:
return nil, fmt.Errorf("unsupported type for interception "+
"request: %v", m)
}
return req, nil
} }
// NewStreamAuthInterceptionRequest creates a new interception request for a // NewStreamAuthInterceptionRequest creates a new interception request for a
@ -484,6 +517,7 @@ func (r *InterceptionRequest) ToRPC(requestID,
StreamRpc: r.StreamRPC, StreamRpc: r.StreamRPC,
TypeName: r.ProtoTypeName, TypeName: r.ProtoTypeName,
Serialized: r.ProtoSerialized, Serialized: r.ProtoSerialized,
IsError: r.IsError,
}, },
} }
@ -549,8 +583,14 @@ func replaceProtoMsg(target interface{}, replacement interface{}) error {
return fmt.Errorf("replacement message is of wrong type") return fmt.Errorf("replacement message is of wrong type")
} }
proto.Reset(targetMsg) replacementBytes, err := proto.Marshal(replacementMsg)
proto.Merge(targetMsg, replacementMsg) if err != nil {
return fmt.Errorf("error marshaling replacement: %v", err)
}
err = proto.Unmarshal(replacementBytes, targetMsg)
if err != nil {
return fmt.Errorf("error unmarshaling replacement: %v", err)
}
return nil return nil
} }