diff --git a/lnwire/custom_records.go b/lnwire/custom_records.go index f0f59185e..8177cbe82 100644 --- a/lnwire/custom_records.go +++ b/lnwire/custom_records.go @@ -89,6 +89,22 @@ func (c CustomRecords) Copy() CustomRecords { return customRecords } +// MergedCopy creates a copy of the records and merges them with the given +// records. If the same key is present in both sets, the value from the other +// records will be used. +func (c CustomRecords) MergedCopy(other CustomRecords) CustomRecords { + copiedRecords := make(CustomRecords, len(c)) + for k, v := range c { + copiedRecords[k] = v + } + + for k, v := range other { + copiedRecords[k] = v + } + + return copiedRecords +} + // ExtendRecordProducers extends the given records slice with the custom // records. The resultant records slice will be sorted if the given records // slice contains TLV types greater than or equal to MinCustomRecordsTlvType. diff --git a/lnwire/custom_records_test.go b/lnwire/custom_records_test.go index 1d30e2100..8ff6af10b 100644 --- a/lnwire/custom_records_test.go +++ b/lnwire/custom_records_test.go @@ -6,6 +6,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -194,3 +195,54 @@ func serializeRecordProducers(t *testing.T, return b.Bytes() } + +func TestCustomRecordsMergedCopy(t *testing.T) { + tests := []struct { + name string + c CustomRecords + other CustomRecords + want CustomRecords + }{ + { + name: "nil records", + want: make(CustomRecords), + }, + { + name: "empty records", + c: make(CustomRecords), + other: make(CustomRecords), + want: make(CustomRecords), + }, + { + name: "distinct records", + c: CustomRecords{ + 1: {1, 2, 3}, + }, + other: CustomRecords{ + 2: {4, 5, 6}, + }, + want: CustomRecords{ + 1: {1, 2, 3}, + 2: {4, 5, 6}, + }, + }, + { + name: "same records, different values", + c: CustomRecords{ + 1: {1, 2, 3}, + }, + other: CustomRecords{ + 1: {4, 5, 6}, + }, + want: CustomRecords{ + 1: {4, 5, 6}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.c.MergedCopy(tt.other) + assert.Equal(t, tt.want, result) + }) + } +}