diff --git a/routing/shards/shard_tracker.go b/routing/shards/shard_tracker.go new file mode 100644 index 000000000..474c85bcc --- /dev/null +++ b/routing/shards/shard_tracker.go @@ -0,0 +1,126 @@ +package shards + +import ( + "fmt" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/record" +) + +// PaymentShard is an interface representing a shard tracked by the +// ShardTracker. It contains options that are specific to the given shard that +// might differ from the overall payment. +type PaymentShard interface { + // Hash returns the hash used for the HTLC representing this shard. + Hash() lntypes.Hash + + // MPP returns any extra MPP records that should be set for the final + // hop on the route used by this shard. + MPP() *record.MPP + + // AMP returns any extra AMP records that should be set for the final + // hop on the route used by this shard. + AMP() *record.AMP +} + +// ShardTracker is an interfae representing a tracker that keeps track of the +// inflight shards of a payment, and is able to assign new shards the correct +// options such as hash and extra records. +type ShardTracker interface { + // NewShard registers a new attempt with the ShardTracker and returns a + // new shard representing this attempt. This attempt's shard should be + // canceled if it ends up not being used by the overall payment, i.e. + // if the attempt fails. + NewShard(uint64, bool) (PaymentShard, error) + + // CancelShard cancel's the shard corresponding to the given attempt + // ID. This lets the ShardTracker free up any slots used by this shard, + // and in case of AMP payments return the share used by this shard to + // the root share. + CancelShard(uint64) error + + // GetHash retrieves the hash used by the shard of the given attempt + // ID. This wil return an error if the attempt ID is unknown. + GetHash(uint64) (lntypes.Hash, error) +} + +// Shard is a struct used for simple shards where we obly need to keep map it +// to a single hash. +type Shard struct { + hash lntypes.Hash +} + +// Hash returns the hash used for the HTLC representing this shard. +func (s *Shard) Hash() lntypes.Hash { + return s.hash +} + +// MPP returns any extra MPP records that should be set for the final hop on +// the route used by this shard. +func (s *Shard) MPP() *record.MPP { + return nil +} + +// AMP returns any extra AMP records that should be set for the final hop on +// the route used by this shard. +func (s *Shard) AMP() *record.AMP { + return nil +} + +// SimpleShardTracker is an implementation of the ShardTracker interface that +// simply maps attempt IDs to hashes. New shards will be given a static payment +// hash. This should be used for regular and MPP payments, in addition to +// resumed payments where all the attempt's hashes have already been created. +type SimpleShardTracker struct { + hash lntypes.Hash + shards map[uint64]lntypes.Hash +} + +// A compile time check to ensure SimpleShardTracker implements the +// ShardTracker interface. +var _ ShardTracker = (*SimpleShardTracker)(nil) + +// NewSimpleShardTracker creates a new intance of the SimpleShardTracker with +// the given payment hash and existing attempts. +func NewSimpleShardTracker(paymentHash lntypes.Hash, + shards map[uint64]lntypes.Hash) ShardTracker { + + if shards == nil { + shards = make(map[uint64]lntypes.Hash) + } + + return &SimpleShardTracker{ + hash: paymentHash, + shards: shards, + } +} + +// NewShard registers a new attempt with the ShardTracker and returns a +// new shard representing this attempt. This attempt's shard should be canceled +// if it ends up not being used by the overall payment, i.e. if the attempt +// fails. +func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) { + m.shards[id] = m.hash + + return &Shard{ + hash: m.hash, + }, nil +} + +// CancelShard cancel's the shard corresponding to the given attempt ID. +func (m *SimpleShardTracker) CancelShard(id uint64) error { + delete(m.shards, id) + return nil +} + +// GetHash retrieves the hash used by the shard of the given attempt ID. This +// will return an error if the attempt ID is unknown. +func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) { + hash, ok := m.shards[id] + if !ok { + return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+ + "not found", id) + } + + return hash, nil +} diff --git a/routing/shards/shard_tracker_test.go b/routing/shards/shard_tracker_test.go new file mode 100644 index 000000000..b755e91fc --- /dev/null +++ b/routing/shards/shard_tracker_test.go @@ -0,0 +1,47 @@ +package shards_test + +import ( + "crypto/rand" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/shards" + "github.com/stretchr/testify/require" +) + +// TestSimpleShardTracker tests that the simple tracker that keeps a map from +// attemptID-> payment hash works. +func TestSimpleShardTracker(t *testing.T) { + var testHashes [2]lntypes.Hash + for i := range testHashes { + _, err := rand.Read(testHashes[i][:]) + require.NoError(t, err) + } + + startAttempts := map[uint64]lntypes.Hash{ + 1: testHashes[1], + } + + tracker := shards.NewSimpleShardTracker(testHashes[0], startAttempts) + + // Trying to retrieve a hash for id 0 should result in an error. + _, err := tracker.GetHash(0) + require.Error(t, err) + + // Getting id 1 should workd. + hash1, err := tracker.GetHash(1) + require.NoError(t, err) + + require.Equal(t, testHashes[1], hash1) + + // Finally, create a new shard and immediately retrieve the hash. + shard, err := tracker.NewShard(2, false) + require.NoError(t, err) + + // It's hash should be the tracker's overall payment hash. + hash2, err := tracker.GetHash(2) + require.NoError(t, err) + + require.Equal(t, testHashes[0], shard.Hash()) + require.Equal(t, testHashes[0], hash2) +}