htlcswitch: pass quit chans as unidirectional

This is a requirement for replacing the quit channel with a Context.
The Done() channel of a Context is always recv-only, so all users of
that channel must not expect a bidirectional channel.
This commit is contained in:
Jonathan Harvey-Buschel 2024-10-17 13:38:31 +02:00 committed by Oliver Gugger
parent afb7532f17
commit 753301cf38
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
7 changed files with 45 additions and 33 deletions

View File

@ -95,7 +95,7 @@ type InterceptableSwitch struct {
type interceptedPackets struct { type interceptedPackets struct {
packets []*htlcPacket packets []*htlcPacket
linkQuit chan struct{} linkQuit <-chan struct{}
isReplay bool isReplay bool
} }
@ -465,8 +465,8 @@ func (s *InterceptableSwitch) Resolve(res *FwdResolution) error {
// interceptor. If the interceptor signals the resume action, the htlcs are // interceptor. If the interceptor signals the resume action, the htlcs are
// forwarded to the switch. The link's quit signal should be provided to allow // forwarded to the switch. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown. // cancellation of forwarding during link shutdown.
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, isReplay bool, func (s *InterceptableSwitch) ForwardPackets(linkQuit <-chan struct{},
packets ...*htlcPacket) error { isReplay bool, packets ...*htlcPacket) error {
// Synchronize with the main event loop. This should be light in the // Synchronize with the main event loop. This should be light in the
// case where there is no interceptor. // case where there is no interceptor.

View File

@ -101,7 +101,7 @@ type ChannelLinkConfig struct {
// switch. The function returns and error in case it fails to send one or // switch. The function returns and error in case it fails to send one or
// more packets. The link's quit signal should be provided to allow // more packets. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown. // cancellation of forwarding during link shutdown.
ForwardPackets func(chan struct{}, bool, ...*htlcPacket) error ForwardPackets func(<-chan struct{}, bool, ...*htlcPacket) error
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion // DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
// blobs, which are then used to inform how to forward an HTLC. // blobs, which are then used to inform how to forward an HTLC.

View File

@ -2197,17 +2197,21 @@ func newSingleLinkTestHarness(t *testing.T, chanAmt,
return nil return nil
} }
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {
return aliceSwitch.ForwardPackets(linkQuit, packets...)
}
// Instantiate with a long interval, so that we can precisely control // Instantiate with a long interval, so that we can precisely control
// the firing via force feeding. // the firing via force feeding.
bticker := ticker.NewForce(time.Hour) bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{ aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy, FwrdingPolicy: globalPolicy,
Peer: alicePeer, Peer: alicePeer,
BestHeight: aliceSwitch.BestHeight, BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(), Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { ForwardPackets: forwardPackets,
return aliceSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators, DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) { hop.ErrorEncrypter, lnwire.FailCode) {
@ -4867,17 +4871,21 @@ func (h *persistentLinkHarness) restartLink(
return nil return nil
} }
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {
return h.hSwitch.ForwardPackets(linkQuit, packets...)
}
// Instantiate with a long interval, so that we can precisely control // Instantiate with a long interval, so that we can precisely control
// the firing via force feeding. // the firing via force feeding.
bticker := ticker.NewForce(time.Hour) bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{ aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy, FwrdingPolicy: globalPolicy,
Peer: alicePeer, Peer: alicePeer,
BestHeight: h.hSwitch.BestHeight, BestHeight: h.hSwitch.BestHeight,
Circuits: h.hSwitch.CircuitModifier(), Circuits: h.hSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { ForwardPackets: forwardPackets,
return h.hSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators, DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) { hop.ErrorEncrypter, lnwire.FailCode) {
@ -7037,7 +7045,7 @@ func TestPipelineSettle(t *testing.T) {
// erroneously forwarded. If the forwardChan is closed before the last // erroneously forwarded. If the forwardChan is closed before the last
// step, then the test will fail. // step, then the test will fail.
forwardChan := make(chan struct{}) forwardChan := make(chan struct{})
fwdPkts := func(c chan struct{}, _ bool, hp ...*htlcPacket) error { fwdPkts := func(c <-chan struct{}, _ bool, hp ...*htlcPacket) error {
close(forwardChan) close(forwardChan)
return nil return nil
} }
@ -7223,7 +7231,7 @@ func TestChannelLinkShortFailureRelay(t *testing.T) {
aliceMsgs := mockPeer.sentMsgs aliceMsgs := mockPeer.sentMsgs
switchChan := make(chan *htlcPacket) switchChan := make(chan *htlcPacket)
coreLink.cfg.ForwardPackets = func(linkQuit chan struct{}, _ bool, coreLink.cfg.ForwardPackets = func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error { packets ...*htlcPacket) error {
for _, p := range packets { for _, p := range packets {

View File

@ -95,7 +95,7 @@ type mailBoxConfig struct {
// forwardPackets send a varidic number of htlcPackets to the switch to // forwardPackets send a varidic number of htlcPackets to the switch to
// be routed. A quit channel should be provided so that the call can // be routed. A quit channel should be provided so that the call can
// properly exit during shutdown. // properly exit during shutdown.
forwardPackets func(chan struct{}, ...*htlcPacket) error forwardPackets func(<-chan struct{}, ...*htlcPacket) error
// clock is a time source for the mailbox. // clock is a time source for the mailbox.
clock clock.Clock clock clock.Clock
@ -804,7 +804,7 @@ type mailOrchConfig struct {
// forwardPackets send a varidic number of htlcPackets to the switch to // forwardPackets send a varidic number of htlcPackets to the switch to
// be routed. A quit channel should be provided so that the call can // be routed. A quit channel should be provided so that the call can
// properly exit during shutdown. // properly exit during shutdown.
forwardPackets func(chan struct{}, ...*htlcPacket) error forwardPackets func(<-chan struct{}, ...*htlcPacket) error
// clock is a time source for the generated mailboxes. // clock is a time source for the generated mailboxes.
clock clock.Clock clock clock.Clock

View File

@ -250,7 +250,7 @@ func newMailboxContext(t *testing.T, startTime time.Time,
return ctx return ctx
} }
func (c *mailboxContext) forward(_ chan struct{}, func (c *mailboxContext) forward(_ <-chan struct{},
pkts ...*htlcPacket) error { pkts ...*htlcPacket) error {
for _, pkt := range pkts { for _, pkt := range pkts {
@ -706,7 +706,7 @@ func TestMailOrchestrator(t *testing.T) {
// First, we'll create a new instance of our orchestrator. // First, we'll create a new instance of our orchestrator.
mo := newMailOrchestrator(&mailOrchConfig{ mo := newMailOrchestrator(&mailOrchConfig{
failMailboxUpdate: failMailboxUpdate, failMailboxUpdate: failMailboxUpdate,
forwardPackets: func(_ chan struct{}, forwardPackets: func(_ <-chan struct{},
pkts ...*htlcPacket) error { pkts ...*htlcPacket) error {
return nil return nil

View File

@ -671,7 +671,7 @@ func (s *Switch) IsForwardedHTLC(chanID lnwire.ShortChannelID,
// given to forward them through the router. The sending link's quit channel is // given to forward them through the router. The sending link's quit channel is
// used to prevent deadlocks when the switch stops a link in the midst of // used to prevent deadlocks when the switch stops a link in the midst of
// forwarding. // forwarding.
func (s *Switch) ForwardPackets(linkQuit chan struct{}, func (s *Switch) ForwardPackets(linkQuit <-chan struct{},
packets ...*htlcPacket) error { packets ...*htlcPacket) error {
var ( var (
@ -849,7 +849,7 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
// receive a shutdown requuest. This method does not wait for a response from // receive a shutdown requuest. This method does not wait for a response from
// the htlcForwarder before returning. // the htlcForwarder before returning.
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
linkQuit chan struct{}) error { linkQuit <-chan struct{}) error {
command := &plexPacket{ command := &plexPacket{
pkt: packet, pkt: packet,

View File

@ -1142,15 +1142,19 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer,
return nil return nil
} }
forwardPackets := func(linkQuit <-chan struct{}, _ bool,
packets ...*htlcPacket) error {
return server.htlcSwitch.ForwardPackets(linkQuit, packets...)
}
link := NewChannelLink( link := NewChannelLink(
ChannelLinkConfig{ ChannelLinkConfig{
BestHeight: server.htlcSwitch.BestHeight, BestHeight: server.htlcSwitch.BestHeight,
FwrdingPolicy: h.globalPolicy, FwrdingPolicy: h.globalPolicy,
Peer: peer, Peer: peer,
Circuits: server.htlcSwitch.CircuitModifier(), Circuits: server.htlcSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error { ForwardPackets: forwardPackets,
return server.htlcSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators, DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) { hop.ErrorEncrypter, lnwire.FailCode) {