diff --git a/routing/additional_edge.go b/routing/additional_edge.go new file mode 100644 index 000000000..22cf3032b --- /dev/null +++ b/routing/additional_edge.go @@ -0,0 +1,107 @@ +package routing + +import ( + "errors" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // ErrNoPayLoadSizeFunc is returned when no payload size function is + // definied. + ErrNoPayLoadSizeFunc = errors.New("no payloadSizeFunc defined for " + + "additional edge") +) + +// AdditionalEdge is an interface which specifies additional edges which can +// be appended to an existing route. Compared to normal edges of a route they +// provide an explicit payload size function and are introduced because blinded +// paths differ in their payload structure. +type AdditionalEdge interface { + // IntermediatePayloadSize returns the size of the payload for the + // additional edge when being an intermediate hop in a route NOT the + // final hop. + IntermediatePayloadSize(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 + + // EdgePolicy returns the policy of the additional edge. + EdgePolicy() *models.CachedEdgePolicy +} + +// PayloadSizeFunc defines the interface for the payload size function. +type PayloadSizeFunc func(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 + +// PrivateEdge implements the AdditionalEdge interface. As the name implies it +// is used for private route hints that the receiver adds for example to an +// invoice. +type PrivateEdge struct { + policy *models.CachedEdgePolicy +} + +// EdgePolicy return the policy of the PrivateEdge. +func (p *PrivateEdge) EdgePolicy() *models.CachedEdgePolicy { + return p.policy +} + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (p *PrivateEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi, + expiry uint32, legacy bool, channelID uint64) uint64 { + + hop := route.Hop{ + AmtToForward: amount, + OutgoingTimeLock: expiry, + LegacyPayload: legacy, + } + + return hop.PayloadSize(channelID) +} + +// BlindedEdge implements the AdditionalEdge interface. Blinded hops are viewed +// as additional edges because they are appened at the end of a normal route. +type BlindedEdge struct { + policy *models.CachedEdgePolicy + cipherText []byte + blindingPoint *btcec.PublicKey +} + +// EdgePolicy return the policy of the BlindedEdge. +func (b *BlindedEdge) EdgePolicy() *models.CachedEdgePolicy { + return b.policy +} + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (b *BlindedEdge) IntermediatePayloadSize(_ lnwire.MilliSatoshi, _ uint32, + _ bool, _ uint64) uint64 { + + hop := route.Hop{ + BlindingPoint: b.blindingPoint, + LegacyPayload: false, + EncryptedData: b.cipherText, + } + + // For blinded paths the next chanID is in the encrypted data tlv. + return hop.PayloadSize(0) +} + +// Compile-time constraints to ensure the PrivateEdge and the BlindedEdge +// implement the AdditionalEdge interface. +var _ AdditionalEdge = (*PrivateEdge)(nil) +var _ AdditionalEdge = (*BlindedEdge)(nil) + +// defaultHopPayloadSize is the default payload size of a normal (not-blinded) +// hop in the route. +func defaultHopPayloadSize(amount lnwire.MilliSatoshi, expiry uint32, + legacy bool, channelID uint64) uint64 { + + // The payload size of a cleartext intermediate hop is equal to the + // payload size of a private edge therefore we reuse its size function. + edge := PrivateEdge{} + + return edge.IntermediatePayloadSize(amount, expiry, legacy, channelID) +} diff --git a/routing/additional_edge_test.go b/routing/additional_edge_test.go new file mode 100644 index 000000000..b3ea3b501 --- /dev/null +++ b/routing/additional_edge_test.go @@ -0,0 +1,136 @@ +package routing + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" +) + +// TestIntermediatePayloadSize tests the payload size functions of the +// PrivateEdge and the BlindedEdge. +func TestIntermediatePayloadSize(t *testing.T) { + t.Parallel() + + testPrivKeyBytes, _ := hex.DecodeString("e126f68f7eafcc8b74f54d269fe" + + "206be715000f94dac067d1c04a8ca3b2db734") + _, blindedPoint := btcec.PrivKeyFromBytes(testPrivKeyBytes) + + testCases := []struct { + name string + hop route.Hop + nextHop uint64 + edge AdditionalEdge + }{ + { + name: "Legacy payload private edge", + hop: route.Hop{ + AmtToForward: 1000, + OutgoingTimeLock: 600000, + ChannelID: 3432483437438, + LegacyPayload: true, + }, + nextHop: 1, + edge: &PrivateEdge{}, + }, + { + name: "Tlv payload private edge", + hop: route.Hop{ + AmtToForward: 1000, + OutgoingTimeLock: 600000, + ChannelID: 3432483437438, + LegacyPayload: false, + }, + nextHop: 1, + edge: &PrivateEdge{}, + }, + { + name: "Blinded edge", + hop: route.Hop{ + EncryptedData: []byte{12, 13}, + }, + edge: &BlindedEdge{ + cipherText: []byte{12, 13}, + }, + }, + { + name: "Blinded edge - introduction point", + hop: route.Hop{ + EncryptedData: []byte{12, 13}, + BlindingPoint: blindedPoint, + }, + edge: &BlindedEdge{ + cipherText: []byte{12, 13}, + blindingPoint: blindedPoint, + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + payLoad, err := createHopPayload( + testCase.hop, testCase.nextHop, false, + ) + require.NoErrorf(t, err, "failed to create hop payload") + + expectedPayloadSize := testCase.edge. + IntermediatePayloadSize( + testCase.hop.AmtToForward, + testCase.hop.OutgoingTimeLock, + testCase.hop.LegacyPayload, + testCase.nextHop, + ) + + require.Equal( + t, expectedPayloadSize, + uint64(payLoad.NumBytes()), + ) + }) + } +} + +// createHopPayload creates the hop payload of the sphinx package to facilitate +// the testing of the payload size. +func createHopPayload(hop route.Hop, nextHop uint64, + finalHop bool) (sphinx.HopPayload, error) { + + // If this is the legacy payload, then we can just include the + // hop data as normal. + if hop.LegacyPayload { + // Before we encode this value, we'll pack the next hop + // into the NextAddress field of the hop info to ensure + // we point to the right now. + hopData := sphinx.HopData{ + ForwardAmount: uint64(hop.AmtToForward), + OutgoingCltv: hop.OutgoingTimeLock, + } + binary.BigEndian.PutUint64( + hopData.NextAddress[:], nextHop, + ) + + return sphinx.NewLegacyHopPayload(&hopData) + } + + // For non-legacy payloads, we'll need to pack the + // routing information, along with any extra TLV + // information into the new per-hop payload format. + // We'll also pass in the chan ID of the hop this + // channel should be forwarded to so we can construct a + // valid payload. + var b bytes.Buffer + err := hop.PackHopPayload(&b, nextHop, finalHop) + if err != nil { + return sphinx.HopPayload{}, err + } + + return sphinx.NewTLVHopPayload(b.Bytes()) +} diff --git a/routing/mocks.go b/routing/mocks.go new file mode 100644 index 000000000..9c019abc9 --- /dev/null +++ b/routing/mocks.go @@ -0,0 +1,32 @@ +package routing + +import ( + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/mock" +) + +// mockAdditionalEdge is a mock of the AdditionalEdge interface. +type mockAdditionalEdge struct{ mock.Mock } + +// IntermediatePayloadSize returns the sphinx payload size defined in BOLT04 if +// this edge were to be included in a route. +func (m *mockAdditionalEdge) IntermediatePayloadSize(amount lnwire.MilliSatoshi, + expiry uint32, legacy bool, channelID uint64) uint64 { + + args := m.Called(amount, expiry, legacy, channelID) + + return args.Get(0).(uint64) +} + +// EdgePolicy return the policy of the mockAdditionalEdge. +func (m *mockAdditionalEdge) EdgePolicy() *models.CachedEdgePolicy { + args := m.Called() + + edgePolicy := args.Get(0) + if edgePolicy == nil { + return nil + } + + return edgePolicy.(*models.CachedEdgePolicy) +}