mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-30 10:35:32 +02:00
rpcperms: intercept errors too
This commit is contained in:
@ -817,14 +817,27 @@ func (r *InterceptorChain) middlewareUnaryServerInterceptor() grpc.UnaryServerIn
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, respErr := handler(ctx, req)
|
// Call the handler, which executes the request against lnd.
|
||||||
if respErr != nil {
|
lndResp, lndErr := handler(ctx, req)
|
||||||
return resp, respErr
|
if lndErr != nil {
|
||||||
|
// The call to lnd ended in an error and not a normal
|
||||||
|
// proto message response. Send the error to the
|
||||||
|
// interceptor as well to inform about the abnormal
|
||||||
|
// termination of the stream and to give the option to
|
||||||
|
// replace the error message with a custom one.
|
||||||
|
replacedErr, err := r.interceptMessage(
|
||||||
|
ctx, TypeResponse, requestID, false,
|
||||||
|
info.FullMethod, lndErr,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return lndResp, replacedErr.(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.interceptMessage(
|
return r.interceptMessage(
|
||||||
ctx, TypeResponse, requestID, false, info.FullMethod,
|
ctx, TypeResponse, requestID, false, info.FullMethod,
|
||||||
resp,
|
lndResp,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -883,7 +896,27 @@ func (r *InterceptorChain) middlewareStreamServerInterceptor() grpc.StreamServer
|
|||||||
interceptor: r,
|
interceptor: r,
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler(srv, wrappedSS)
|
// Call the stream handler, which will block as long as the
|
||||||
|
// stream is alive.
|
||||||
|
lndErr := handler(srv, wrappedSS)
|
||||||
|
if lndErr != nil {
|
||||||
|
// This is an error being returned from lnd. Send it to
|
||||||
|
// the interceptor as well to inform about the abnormal
|
||||||
|
// termination of the stream and to give the option to
|
||||||
|
// replace the error message with a custom one.
|
||||||
|
replacedErr, err := r.interceptMessage(
|
||||||
|
ss.Context(), TypeResponse, requestID,
|
||||||
|
true, info.FullMethod, lndErr,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return replacedErr.(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal/successful termination of the stream.
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,10 +263,28 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// For intercepted messages we also allow the
|
// If there's nothing to replace, we're done,
|
||||||
// content itself to be overwritten.
|
// this request was just accepted.
|
||||||
if t.ReplaceResponse {
|
if !t.ReplaceResponse {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// We are replacing the response, the question
|
||||||
|
// now just is: was it an error or a proper
|
||||||
|
// proto message?
|
||||||
response.replace = true
|
response.replace = true
|
||||||
|
if requestInfo.request.IsError {
|
||||||
|
response.replacement = errors.New(
|
||||||
|
string(t.ReplacementSerialized),
|
||||||
|
)
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not an error but a proper proto message that
|
||||||
|
// needs to be replaced. For that we need to
|
||||||
|
// parse it from the raw bytes into the full RPC
|
||||||
|
// message.
|
||||||
protoMsg, err := parseProto(
|
protoMsg, err := parseProto(
|
||||||
requestInfo.request.ProtoTypeName,
|
requestInfo.request.ProtoTypeName,
|
||||||
t.ReplacementSerialized,
|
t.ReplacementSerialized,
|
||||||
@ -279,7 +297,6 @@ func (h *MiddlewareHandler) sendInterceptRequests(errChan chan error,
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.replacement = protoMsg
|
response.replacement = protoMsg
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown middleware "+
|
return fmt.Errorf("unknown middleware "+
|
||||||
@ -369,6 +386,10 @@ type InterceptionRequest struct {
|
|||||||
// ProtoTypeName is the fully qualified name of the protobuf type of the
|
// ProtoTypeName is the fully qualified name of the protobuf type of the
|
||||||
// request or response message that is serialized in the field above.
|
// request or response message that is serialized in the field above.
|
||||||
ProtoTypeName string
|
ProtoTypeName string
|
||||||
|
|
||||||
|
// IsError indicates that the message contained within this request is
|
||||||
|
// an error. Will only ever be true for response messages.
|
||||||
|
IsError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessageInterceptionRequest creates a new interception request for either
|
// NewMessageInterceptionRequest creates a new interception request for either
|
||||||
@ -382,24 +403,36 @@ func NewMessageInterceptionRequest(ctx context.Context,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rpcReq, ok := m.(proto.Message)
|
req := &InterceptionRequest{
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("msg is not proto message: %v", m)
|
|
||||||
}
|
|
||||||
rawRequest, err := proto.Marshal(rpcReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot marshal proto msg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &InterceptionRequest{
|
|
||||||
Type: authType,
|
Type: authType,
|
||||||
StreamRPC: isStream,
|
StreamRPC: isStream,
|
||||||
Macaroon: mac,
|
Macaroon: mac,
|
||||||
RawMacaroon: rawMacaroon,
|
RawMacaroon: rawMacaroon,
|
||||||
FullURI: fullMethod,
|
FullURI: fullMethod,
|
||||||
ProtoSerialized: rawRequest,
|
}
|
||||||
ProtoTypeName: string(proto.MessageName(rpcReq)),
|
|
||||||
}, nil
|
// The message is either a proto message or an error, we don't support
|
||||||
|
// any other types being intercepted.
|
||||||
|
switch t := m.(type) {
|
||||||
|
case proto.Message:
|
||||||
|
req.ProtoSerialized, err = proto.Marshal(t)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot marshal proto msg: %v",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
req.ProtoTypeName = string(proto.MessageName(t))
|
||||||
|
|
||||||
|
case error:
|
||||||
|
req.ProtoSerialized = []byte(t.Error())
|
||||||
|
req.ProtoTypeName = "error"
|
||||||
|
req.IsError = true
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported type for interception "+
|
||||||
|
"request: %v", m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStreamAuthInterceptionRequest creates a new interception request for a
|
// NewStreamAuthInterceptionRequest creates a new interception request for a
|
||||||
@ -484,6 +517,7 @@ func (r *InterceptionRequest) ToRPC(requestID,
|
|||||||
StreamRpc: r.StreamRPC,
|
StreamRpc: r.StreamRPC,
|
||||||
TypeName: r.ProtoTypeName,
|
TypeName: r.ProtoTypeName,
|
||||||
Serialized: r.ProtoSerialized,
|
Serialized: r.ProtoSerialized,
|
||||||
|
IsError: r.IsError,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -549,8 +583,14 @@ func replaceProtoMsg(target interface{}, replacement interface{}) error {
|
|||||||
return fmt.Errorf("replacement message is of wrong type")
|
return fmt.Errorf("replacement message is of wrong type")
|
||||||
}
|
}
|
||||||
|
|
||||||
proto.Reset(targetMsg)
|
replacementBytes, err := proto.Marshal(replacementMsg)
|
||||||
proto.Merge(targetMsg, replacementMsg)
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshaling replacement: %v", err)
|
||||||
|
}
|
||||||
|
err = proto.Unmarshal(replacementBytes, targetMsg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error unmarshaling replacement: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user