diff --git a/channeldb/channel.go b/channeldb/channel.go index f91c914d8..9d7d01bae 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -497,6 +497,7 @@ func (c *OpenChannel) RefreshShortChanID() error { } c.ShortChannelID = sid + c.Packager = NewChannelPackager(sid) return nil } @@ -665,6 +666,7 @@ func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { c.IsPending = false c.ShortChannelID = openLoc + c.Packager = NewChannelPackager(openLoc) return nil } @@ -1474,6 +1476,9 @@ func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { // processed, and returns their deserialized log updates in map indexed by the // remote commitment height at which the updates were locked in. func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + var fwdPkgs []*FwdPkg if err := c.Db.View(func(tx *bolt.Tx) error { var err error @@ -1489,6 +1494,9 @@ func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { // SetFwdFilter atomically sets the forwarding filter for the forwarding package // identified by `height`. func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.SetFwdFilter(tx, height, fwdFilter) }) @@ -1499,6 +1507,9 @@ func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { // // NOTE: This method should only be called on packages marked FwdStateCompleted. func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + c.Lock() + defer c.Unlock() + return c.Db.Update(func(tx *bolt.Tx) error { return c.Packager.RemovePkg(tx, height) }) diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index f7f331af7..efd36abbd 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -898,6 +898,16 @@ func TestRefreshShortChanID(t *testing.T) { "updated before refreshing short_chan_id") } + // Now that the receiver's short channel id has been updated, check to + // ensure that the channel packager's source has been updated as well. + // This ensures that the packager will read and write to buckets + // corresponding to the new short chan id, instead of the prior. + if state.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + state.Packager.(*ChannelPackager).source) + } + // Now, refresh the short channel ID of the pending channel. err = pendingChannel.RefreshShortChanID() if err != nil { @@ -911,4 +921,14 @@ func TestRefreshShortChanID(t *testing.T) { "refreshed: want %v, got %v", state.ShortChanID(), pendingChannel.ShortChanID()) } + + // Check to ensure that the _other_ OpenChannel channel packager's + // source has also been updated after the refresh. This ensures that the + // other packagers will read and write to buckets corresponding to the + // updated short chan id. + if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + pendingChannel.Packager.(*ChannelPackager).source) + } }