htlcswitch: don't pass pending update counts into quiescer

This change simplifies some of the quiescer responsibilities in
favor of making the link check whether or not it has a clean state
to be able to send or receive an stfu. This change was made on the
basis that the only use the quiescer makes of this information is
to assess that it is or is not zero. Further the difficulty of
checking this condition in the link is barely more burdensome than
selecting the proper information to pass to the quiescer anyway.
This commit is contained in:
Keagan McClelland
2024-11-12 15:06:02 -07:00
parent a4c49a88f1
commit ac0c24aa7b
3 changed files with 93 additions and 201 deletions

View File

@@ -76,7 +76,7 @@ type Quiescer interface {
InitStfu(req StfuReq)
// RecvStfu is called when we receive an Stfu message from the remote.
RecvStfu(stfu lnwire.Stfu, numRemotePendingUpdates uint64) error
RecvStfu(stfu lnwire.Stfu) error
// CanRecvUpdates returns true if we haven't yet received an Stfu which
// would mark the end of the remote's ability to send updates.
@@ -88,7 +88,7 @@ type Quiescer interface {
// SendOwedStfu sends Stfu if it owes one. It returns an error if the
// state machine is in an invalid state.
SendOwedStfu(numPendingLocalUpdates uint64) error
SendOwedStfu() error
// OnResume accepts a no return closure that will run when the quiescer
// is resumed.
@@ -175,19 +175,15 @@ func NewQuiescer(cfg QuiescerCfg) Quiescer {
}
// RecvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
func (q *QuiescerLive) RecvStfu(msg lnwire.Stfu) error {
q.Lock()
defer q.Unlock()
return q.recvStfu(msg, numPendingRemoteUpdates)
return q.recvStfu(msg)
}
// recvStfu is called when we receive an Stfu message from the remote.
func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
func (q *QuiescerLive) recvStfu(msg lnwire.Stfu) error {
// At the time of this writing, this check that we have already received
// an Stfu is not strictly necessary, according to the specification.
// However, it is fishy if we do and it is unclear how we should handle
@@ -203,7 +199,7 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
q.cfg.chanID)
}
if !q.canRecvStfu(numPendingRemoteUpdates) {
if !q.canRecvStfu() {
return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates,
q.cfg.chanID)
}
@@ -228,26 +224,22 @@ func (q *QuiescerLive) recvStfu(msg lnwire.Stfu,
// MakeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) MakeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
func (q *QuiescerLive) MakeStfu() fn.Result[lnwire.Stfu] {
q.RLock()
defer q.RUnlock()
return q.makeStfu(numPendingLocalUpdates)
return q.makeStfu()
}
// makeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *QuiescerLive) makeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
func (q *QuiescerLive) makeStfu() fn.Result[lnwire.Stfu] {
if q.sent {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrStfuAlreadySent, q.cfg.chanID)
}
if !q.canSendStfu(numPendingLocalUpdates) {
if !q.canSendStfu() {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrPendingLocalUpdates, q.cfg.chanID)
}
@@ -380,44 +372,44 @@ func (q *QuiescerLive) CanSendStfu(numPendingLocalUpdates uint64) bool {
q.RLock()
defer q.RUnlock()
return q.canSendStfu(numPendingLocalUpdates)
return q.canSendStfu()
}
// canSendStfu returns true if we can send an Stfu.
func (q *QuiescerLive) canSendStfu(numPendingLocalUpdates uint64) bool {
return numPendingLocalUpdates == 0 && !q.sent
func (q *QuiescerLive) canSendStfu() bool {
return !q.sent
}
// CanRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) CanRecvStfu(numPendingRemoteUpdates uint64) bool {
func (q *QuiescerLive) CanRecvStfu() bool {
q.RLock()
defer q.RUnlock()
return q.canRecvStfu(numPendingRemoteUpdates)
return q.canRecvStfu()
}
// canRecvStfu returns true if we can receive an Stfu.
func (q *QuiescerLive) canRecvStfu(numPendingRemoteUpdates uint64) bool {
return numPendingRemoteUpdates == 0 && !q.received
func (q *QuiescerLive) canRecvStfu() bool {
return !q.received
}
// SendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) SendOwedStfu(numPendingLocalUpdates uint64) error {
func (q *QuiescerLive) SendOwedStfu() error {
q.Lock()
defer q.Unlock()
return q.sendOwedStfu(numPendingLocalUpdates)
return q.sendOwedStfu()
}
// sendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *QuiescerLive) sendOwedStfu(numPendingLocalUpdates uint64) error {
if !q.oweStfu() || !q.canSendStfu(numPendingLocalUpdates) {
func (q *QuiescerLive) sendOwedStfu() error {
if !q.oweStfu() || !q.canSendStfu() {
return nil
}
err := q.makeStfu(numPendingLocalUpdates).Sink(q.cfg.sendMsg)
err := q.makeStfu().Sink(q.cfg.sendMsg)
if err == nil {
q.sent = true
@@ -561,13 +553,13 @@ var _ Quiescer = (*quiescerNoop)(nil)
func (q *quiescerNoop) InitStfu(req StfuReq) {
req.Resolve(fn.Errf[lntypes.ChannelParty]("quiescence not supported"))
}
func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu, _ uint64) error { return nil }
func (q *quiescerNoop) CanRecvUpdates() bool { return true }
func (q *quiescerNoop) CanSendUpdates() bool { return true }
func (q *quiescerNoop) SendOwedStfu(_ uint64) error { return nil }
func (q *quiescerNoop) IsQuiescent() bool { return false }
func (q *quiescerNoop) OnResume(hook func()) { hook() }
func (q *quiescerNoop) Resume() {}
func (q *quiescerNoop) RecvStfu(_ lnwire.Stfu) error { return nil }
func (q *quiescerNoop) CanRecvUpdates() bool { return true }
func (q *quiescerNoop) CanSendUpdates() bool { return true }
func (q *quiescerNoop) SendOwedStfu() error { return nil }
func (q *quiescerNoop) IsQuiescent() bool { return false }
func (q *quiescerNoop) OnResume(hook func()) { hook() }
func (q *quiescerNoop) Resume() {}
func (q *quiescerNoop) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] {
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
}