htlcswitch: skip checking replays for reforwarded packets

We now rely on the forwarding package's state to decide whether a given
packet is a reforwarding or not. If we know it's a reforwarding packet,
there's no need to check for replays in the `sharedHashes` bucket, which
behaves the same as if we are querying the `batchReplayBkt`.
This commit is contained in:
yyforyongyu
2025-06-11 12:16:09 +08:00
parent a5c4a7c547
commit fb95458a1b
3 changed files with 19 additions and 15 deletions

View File

@@ -745,7 +745,8 @@ func (r *DecodeHopIteratorResponse) Result() (Iterator, lnwire.FailCode) {
// the presented readers and rhashes *NEVER* deviate across invocations for the // the presented readers and rhashes *NEVER* deviate across invocations for the
// same id. // same id.
func (p *OnionProcessor) DecodeHopIterators(id []byte, func (p *OnionProcessor) DecodeHopIterators(id []byte,
reqs []DecodeHopIteratorRequest) ([]DecodeHopIteratorResponse, error) { reqs []DecodeHopIteratorRequest,
reforward bool) ([]DecodeHopIteratorResponse, error) {
var ( var (
batchSize = len(reqs) batchSize = len(reqs)
@@ -864,11 +865,12 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte,
continue continue
} }
// If this index is contained in the replay set, mark it with a // If this index is contained in the replay set, and it is not a
// temporary channel failure error code. We infer that the // reforwarding on startup, mark it with a permanent channel
// offending error was due to a replayed packet because this // failure error code. We infer that the offending error was due
// index was found in the replay set. // to a replayed packet because this index was found in the
if replays.Contains(uint16(i)) { // replay set.
if !reforward && replays.Contains(uint16(i)) {
log.Errorf("unable to process onion packet: %v", log.Errorf("unable to process onion packet: %v",
sphinx.ErrReplayedPacket) sphinx.ErrReplayedPacket)

View File

@@ -108,8 +108,10 @@ type ChannelLinkConfig struct {
// blobs, which are then used to inform how to forward an HTLC. // blobs, which are then used to inform how to forward an HTLC.
// //
// NOTE: This function assumes the same set of readers and preimages // NOTE: This function assumes the same set of readers and preimages
// are always presented for the same identifier. // are always presented for the same identifier. The last boolean is
DecodeHopIterators func([]byte, []hop.DecodeHopIteratorRequest) ( // used to decide whether this is a reforwarding or not - when it's
// reforwarding, we skip the replay check enforced in our decay log.
DecodeHopIterators func([]byte, []hop.DecodeHopIteratorRequest, bool) (
[]hop.DecodeHopIteratorResponse, error) []hop.DecodeHopIteratorResponse, error)
// ExtractErrorEncrypter function is responsible for decoding HTLC // ExtractErrorEncrypter function is responsible for decoding HTLC
@@ -3764,12 +3766,14 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) {
} }
} }
reforward := fwdPkg.State != channeldb.FwdStateLockedIn
// Atomically decode the incoming htlcs, simultaneously checking for // Atomically decode the incoming htlcs, simultaneously checking for
// replay attempts. A particular index in the returned, spare list of // replay attempts. A particular index in the returned, spare list of
// channel iterators should only be used if the failure code at the // channel iterators should only be used if the failure code at the
// same index is lnwire.FailCodeNone. // same index is lnwire.FailCodeNone.
decodeResps, sphinxErr := l.cfg.DecodeHopIterators( decodeResps, sphinxErr := l.cfg.DecodeHopIterators(
fwdPkg.ID(), decodeReqs, fwdPkg.ID(), decodeReqs, reforward,
) )
if sphinxErr != nil { if sphinxErr != nil {
l.failf(LinkFailureError{code: ErrInternalError}, l.failf(LinkFailureError{code: ErrInternalError},
@@ -4120,17 +4124,15 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg) {
return return
} }
replay := fwdPkg.State != channeldb.FwdStateLockedIn l.log.Debugf("forwarding %d packets to switch: reforward=%v",
len(switchPackets), reforward)
l.log.Debugf("forwarding %d packets to switch: replay=%v",
len(switchPackets), replay)
// NOTE: This call is made synchronous so that we ensure all circuits // NOTE: This call is made synchronous so that we ensure all circuits
// are committed in the exact order that they are processed in the link. // are committed in the exact order that they are processed in the link.
// Failing to do this could cause reorderings/gaps in the range of // Failing to do this could cause reorderings/gaps in the range of
// opened circuits, which violates assumptions made by the circuit // opened circuits, which violates assumptions made by the circuit
// trimming. // trimming.
l.forwardBatch(replay, switchPackets...) l.forwardBatch(reforward, switchPackets...)
} }
// experimentalEndorsement returns the value to set for our outgoing // experimentalEndorsement returns the value to set for our outgoing

View File

@@ -522,7 +522,7 @@ func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
} }
func (p *mockIteratorDecoder) DecodeHopIterators(id []byte, func (p *mockIteratorDecoder) DecodeHopIterators(id []byte,
reqs []hop.DecodeHopIteratorRequest) ( reqs []hop.DecodeHopIteratorRequest, _ bool) (
[]hop.DecodeHopIteratorResponse, error) { []hop.DecodeHopIteratorResponse, error) {
idHash := sha256.Sum256(id) idHash := sha256.Sum256(id)