diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index cd3090a92..75866beba 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -135,14 +135,16 @@ type ChannelUpdateHandler interface { MayAddOutgoingHtlc(lnwire.MilliSatoshi) error // EnableAdds sets the ChannelUpdateHandler state to allow - // UpdateAddHtlc's in the specified direction. It returns an error if - // the state already allowed those adds. - EnableAdds(direction LinkDirection) error + // UpdateAddHtlc's in the specified direction. It returns true if the + // state was changed and false if the desired state was already set + // before the method was called. + EnableAdds(direction LinkDirection) bool // DisableAdds sets the ChannelUpdateHandler state to allow - // UpdateAddHtlc's in the specified direction. It returns an error if - // the state already disallowed those adds. - DisableAdds(direction LinkDirection) error + // UpdateAddHtlc's in the specified direction. It returns true if the + // state was changed and false if the desired state was already set + // before the method was called. + DisableAdds(direction LinkDirection) bool // IsFlushing returns true when UpdateAddHtlc's are disabled in the // direction of the argument. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index ab67e736b..29532f572 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -618,41 +618,25 @@ func (l *channelLink) EligibleToUpdate() bool { } // EnableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in -// the specified direction. It returns an error if the state already allowed -// those adds. -func (l *channelLink) EnableAdds(linkDirection LinkDirection) error { +// the specified direction. It returns true if the state was changed and false +// if the desired state was already set before the method was called. +func (l *channelLink) EnableAdds(linkDirection LinkDirection) bool { if linkDirection == Outgoing { - if !l.isOutgoingAddBlocked.Swap(false) { - return errors.New("outgoing adds already enabled") - } + return l.isOutgoingAddBlocked.Swap(false) } - if linkDirection == Incoming { - if !l.isIncomingAddBlocked.Swap(false) { - return errors.New("incoming adds already enabled") - } - } - - return nil + return l.isIncomingAddBlocked.Swap(false) } // DisableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in -// the specified direction. It returns an error if the state already disallowed -// those adds. -func (l *channelLink) DisableAdds(linkDirection LinkDirection) error { +// the specified direction. It returns true if the state was changed and false +// if the desired state was already set before the method was called. +func (l *channelLink) DisableAdds(linkDirection LinkDirection) bool { if linkDirection == Outgoing { - if l.isOutgoingAddBlocked.Swap(true) { - return errors.New("outgoing adds already disabled") - } + return !l.isOutgoingAddBlocked.Swap(true) } - if linkDirection == Incoming { - if l.isIncomingAddBlocked.Swap(true) { - return errors.New("incoming adds already disabled") - } - } - - return nil + return !l.isIncomingAddBlocked.Swap(true) } // IsFlushing returns true when UpdateAddHtlc's are disabled in the direction of diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index cac26a6de..d9e583876 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6969,27 +6969,22 @@ func TestLinkFlushApiDirectionIsolation(t *testing.T) { for i := 0; i < 10; i++ { if prand.Uint64()%2 == 0 { - //nolint:errcheck aliceLink.EnableAdds(Outgoing) require.False(t, aliceLink.IsFlushing(Outgoing)) } else { - //nolint:errcheck aliceLink.DisableAdds(Outgoing) require.True(t, aliceLink.IsFlushing(Outgoing)) } require.False(t, aliceLink.IsFlushing(Incoming)) } - //nolint:errcheck aliceLink.EnableAdds(Outgoing) for i := 0; i < 10; i++ { if prand.Uint64()%2 == 0 { - //nolint:errcheck aliceLink.EnableAdds(Incoming) require.False(t, aliceLink.IsFlushing(Incoming)) } else { - //nolint:errcheck aliceLink.DisableAdds(Incoming) require.True(t, aliceLink.IsFlushing(Incoming)) } @@ -7010,16 +7005,16 @@ func TestLinkFlushApiGateStateIdempotence(t *testing.T) { ) for _, dir := range []LinkDirection{Incoming, Outgoing} { - require.Nil(t, aliceLink.DisableAdds(dir)) + require.True(t, aliceLink.DisableAdds(dir)) require.True(t, aliceLink.IsFlushing(dir)) - require.NotNil(t, aliceLink.DisableAdds(dir)) + require.False(t, aliceLink.DisableAdds(dir)) require.True(t, aliceLink.IsFlushing(dir)) - require.Nil(t, aliceLink.EnableAdds(dir)) + require.True(t, aliceLink.EnableAdds(dir)) require.False(t, aliceLink.IsFlushing(dir)) - require.NotNil(t, aliceLink.EnableAdds(dir)) + require.False(t, aliceLink.EnableAdds(dir)) require.False(t, aliceLink.IsFlushing(dir)) } } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 5c7722a54..ab6fbe76a 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -906,13 +906,14 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) { return f.shortChanID, nil } -func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) error { +func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement - return nil + return true } -func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) error { + +func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement - return nil + return true } func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool { // TODO(proofofkeags): Implement diff --git a/peer/brontide.go b/peer/brontide.go index 7587d513c..22841df11 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -2991,9 +2991,8 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { } link.OnCommitOnce(htlcswitch.Outgoing, func() { - err := link.DisableAdds(htlcswitch.Outgoing) - if err != nil { - p.log.Warnf("outgoing link adds already "+ + if !link.DisableAdds(htlcswitch.Outgoing) { + p.log.Warnf("Outgoing link adds already "+ "disabled: %v", link.ChanID()) } @@ -3619,12 +3618,9 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { switch typed := msg.msg.(type) { case *lnwire.Shutdown: // Disable incoming adds immediately. - if link != nil { - err := link.DisableAdds(htlcswitch.Incoming) - if err != nil { - p.log.Warnf("incoming link adds already "+ - "disabled: %v", link.ChanID()) - } + if link != nil && !link.DisableAdds(htlcswitch.Incoming) { + p.log.Warnf("Incoming link adds already disabled: %v", + link.ChanID()) } oShutdown, err := chanCloser.ReceiveShutdown(*typed) @@ -3646,10 +3642,12 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { // next time we send a CommitSig to remain spec // compliant. link.OnCommitOnce(htlcswitch.Outgoing, func() { - err := link.DisableAdds(htlcswitch.Outgoing) - if err != nil { - p.log.Warn(err.Error()) + if !link.DisableAdds(htlcswitch.Outgoing) { + p.log.Warnf("Outgoing link adds "+ + "already disabled: %v", + link.ChanID()) } + p.queueMsg(&msg, nil) }) }) diff --git a/peer/test_utils.go b/peer/test_utils.go index e4b5d6086..05bfe6ad4 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -4,7 +4,6 @@ import ( "bytes" crand "crypto/rand" "encoding/binary" - "fmt" "io" "math/rand" "net" @@ -510,34 +509,20 @@ type mockMessageConn struct { readRaceDetectingCounter int } -func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) error { - switch dir { - case htlcswitch.Outgoing: - if !m.isOutgoingAddBlocked.Swap(false) { - return fmt.Errorf("%v adds already enabled", dir) - } - case htlcswitch.Incoming: - if !m.isIncomingAddBlocked.Swap(false) { - return fmt.Errorf("%v adds already enabled", dir) - } +func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) bool { + if dir == htlcswitch.Outgoing { + return m.isOutgoingAddBlocked.Swap(false) } - return nil + return m.isIncomingAddBlocked.Swap(false) } -func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) error { - switch dir { - case htlcswitch.Outgoing: - if m.isOutgoingAddBlocked.Swap(true) { - return fmt.Errorf("%v adds already disabled", dir) - } - case htlcswitch.Incoming: - if m.isIncomingAddBlocked.Swap(true) { - return fmt.Errorf("%v adds already disabled", dir) - } +func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) bool { + if dir == htlcswitch.Outgoing { + return !m.isOutgoingAddBlocked.Swap(true) } - return nil + return !m.isIncomingAddBlocked.Swap(true) } func (m *mockUpdateHandler) IsFlushing(dir htlcswitch.LinkDirection) bool {