diff --git a/channeldb/channel.go b/channeldb/channel.go index a16fa1c16..701040002 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -726,44 +726,16 @@ func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { c.Lock() defer c.Unlock() - var status ChannelStatus - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Add status LocalDataLoss to the existing bitvector found in - // the DB. - status = channel.chanStatus | ChanStatusLocalDataLoss - channel.chanStatus = status - - var b bytes.Buffer - if err := WriteElement(&b, commitPoint); err != nil { - return err - } - - err = chanBucket.Put(dataLossCommitPointKey, b.Bytes()) - if err != nil { - return err - } - - return putOpenChannel(chanBucket, channel) - }); err != nil { + var b bytes.Buffer + if err := WriteElement(&b, commitPoint); err != nil { return err } - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status + putCommitPoint := func(chanBucket *bbolt.Bucket) error { + return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) + } - return nil + return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) } // DataLossCommitPoint retrieves the stored commit point set during @@ -914,7 +886,12 @@ func (c *OpenChannel) MarkCommitmentBroadcasted() error { return c.putChanStatus(ChanStatusCommitBroadcasted) } -func (c *OpenChannel) putChanStatus(status ChannelStatus) error { +// putChanStatus appends the given status to the channel. fs is an optional +// list of closures that are given the chanBucket in order to atomically add +// extra information together with the new status. +func (c *OpenChannel) putChanStatus(status ChannelStatus, + fs ...func(*bbolt.Bucket) error) error { + if err := c.Db.Update(func(tx *bbolt.Tx) error { chanBucket, err := fetchChanBucket( tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, @@ -932,7 +909,17 @@ func (c *OpenChannel) putChanStatus(status ChannelStatus) error { status = channel.chanStatus | status channel.chanStatus = status - return putOpenChannel(chanBucket, channel) + if err := putOpenChannel(chanBucket, channel); err != nil { + return err + } + + for _, f := range fs { + if err := f(chanBucket); err != nil { + return err + } + } + + return nil }); err != nil { return err }