diff --git a/accessman.go b/accessman.go index 9667fbc2e..0d1eda022 100644 --- a/accessman.go +++ b/accessman.go @@ -35,6 +35,10 @@ type accessMan struct { // is the string-version of the serialized public key. // // NOTE: This MUST be accessed with the banScoreMtx held. + // + // TODO(yy): unify `peerScores` and `peerCounts` - there's no need to + // create two maps tracking essentially the same info. `numRestricted` + // can also be derived from `peerCounts`. peerScores map[string]peerSlotStatus // numRestricted tracks the number of peers with restricted access in @@ -77,6 +81,48 @@ func newAccessMan(cfg *accessManConfig) (*accessMan, error) { return a, nil } +// hasPeer checks whether a given peer already exists in the internal maps. +func (a *accessMan) hasPeer(ctx context.Context, + pub string) (peerAccessStatus, bool) { + + // Lock banScoreMtx for reading so that we can read the banning maps + // below. + a.banScoreMtx.RLock() + defer a.banScoreMtx.RUnlock() + + count, found := a.peerCounts[pub] + if found { + if count.HasOpenOrClosedChan { + acsmLog.DebugS(ctx, "Peer has open/closed channel, "+ + "assigning protected access") + + // Exit early if the peer is no longer restricted. + return peerStatusProtected, true + } + + if count.PendingOpenCount != 0 { + acsmLog.DebugS(ctx, "Peer has pending channel(s), "+ + "assigning temporary access") + + // Exit early if the peer is no longer restricted. + return peerStatusTemporary, true + } + + return peerStatusRestricted, true + } + + // Check if the peer is found in the scores map. + status, found := a.peerScores[pub] + if found { + acsmLog.DebugS(ctx, "Peer already has access", "access", + status.state) + + return status.state, true + } + + return peerStatusRestricted, false +} + // assignPeerPerms assigns a new peer its permissions. This does not track the // access in the maps. This is intentional. func (a *accessMan) assignPeerPerms(remotePub *btcec.PublicKey) ( @@ -91,31 +137,15 @@ func (a *accessMan) assignPeerPerms(remotePub *btcec.PublicKey) ( acsmLog.DebugS(ctx, "Assigning permissions") // Default is restricted unless the below filters say otherwise. - access := peerStatusRestricted + access, peerExist := a.hasPeer(ctx, peerMapKey) - // Lock banScoreMtx for reading so that we can update the banning maps - // below. - a.banScoreMtx.RLock() - if count, found := a.peerCounts[peerMapKey]; found { - if count.HasOpenOrClosedChan { - acsmLog.DebugS(ctx, "Peer has open/closed channel, "+ - "assigning protected access") - - access = peerStatusProtected - } else if count.PendingOpenCount != 0 { - acsmLog.DebugS(ctx, "Peer has pending channel(s), "+ - "assigning temporary access") - - access = peerStatusTemporary - } - } - a.banScoreMtx.RUnlock() - - // Exit early if the peer status is no longer restricted. + // Exit early if the peer is not restricted. if access != peerStatusRestricted { return access, nil } + // If we are here, it means the peer has peerStatusRestricted. + // // Check whether this peer is banned. shouldDisconnect, err := a.cfg.shouldDisconnect(remotePub) if err != nil { @@ -138,6 +168,12 @@ func (a *accessMan) assignPeerPerms(remotePub *btcec.PublicKey) ( // peer. acsmLog.DebugS(ctx, "Peer has no channels, assigning restricted access") + // If this is an existing peer, there's no need to check for slot limit. + if peerExist { + acsmLog.DebugS(ctx, "Skipped slot check for existing peer") + return access, nil + } + a.banScoreMtx.RLock() defer a.banScoreMtx.RUnlock() diff --git a/accessman_test.go b/accessman_test.go index 0663a9b4e..b67d4f690 100644 --- a/accessman_test.go +++ b/accessman_test.go @@ -1,6 +1,7 @@ package lnd import ( + "context" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -250,9 +251,8 @@ func TestAssignPeerPerms(t *testing.T) { expectedErr: ErrGossiperBan, }, // peer6 has no channel with us, and we expect it to have a - // restricted status. We also expect the error - // `ErrNoMoreRestrictedAccessSlots` to be returned given - // we only allow 1 restricted peer in this test. + // restricted status. Since this peer is seen, we don't expect + // the error `ErrNoMoreRestrictedAccessSlots` to be returned. { name: "peer with no channels and restricted", peerPub: genPeerPub(), @@ -264,7 +264,7 @@ func TestAssignPeerPerms(t *testing.T) { numRestricted: 1, expectedStatus: peerStatusRestricted, - expectedErr: ErrNoMoreRestrictedAccessSlots, + expectedErr: nil, }, } @@ -394,3 +394,135 @@ func TestAssignPeerPermsBypassRestriction(t *testing.T) { }) } } + +// TestAssignPeerPermsBypassExisting asserts that when the peer is a +// pre-existing peer, it won't be restricted. +func TestAssignPeerPermsBypassExisting(t *testing.T) { + t.Parallel() + + // genPeerPub is a helper closure that generates a random public key. + genPeerPub := func() *btcec.PublicKey { + peerPriv, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return peerPriv.PubKey() + } + + // peer1 exists in `peerCounts` map. + peer1 := genPeerPub() + peer1Str := string(peer1.SerializeCompressed()) + + // peer2 exists in `peerScores` map. + peer2 := genPeerPub() + peer2Str := string(peer2.SerializeCompressed()) + + // peer3 is a new peer. + peer3 := genPeerPub() + + // Create params to init the accessman. + initPerms := func() (map[string]channeldb.ChanCount, error) { + return map[string]channeldb.ChanCount{ + peer1Str: {}, + }, nil + } + + disconnect := func(*btcec.PublicKey) (bool, error) { + return false, nil + } + + cfg := &accessManConfig{ + initAccessPerms: initPerms, + shouldDisconnect: disconnect, + maxRestrictedSlots: 0, + } + + a, err := newAccessMan(cfg) + require.NoError(t, err) + + // Add peer2 to the `peerScores`. + a.peerScores[peer2Str] = peerSlotStatus{ + state: peerStatusTemporary, + } + + // Assigning to peer1 should not return an error. + status, err := a.assignPeerPerms(peer1) + require.NoError(t, err) + require.Equal(t, peerStatusRestricted, status) + + // Assigning to peer2 should not return an error. + status, err = a.assignPeerPerms(peer2) + require.NoError(t, err) + require.Equal(t, peerStatusTemporary, status) + + // Assigning to peer3 should return an error. + status, err = a.assignPeerPerms(peer3) + require.ErrorIs(t, err, ErrNoMoreRestrictedAccessSlots) + require.Equal(t, peerStatusRestricted, status) +} + +// TestHasPeer asserts `hasPeer` returns the correct results. +func TestHasPeer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create a testing accessMan. + a := &accessMan{ + peerCounts: make(map[string]channeldb.ChanCount), + peerScores: make(map[string]peerSlotStatus), + } + + // peer1 exists with an open channel. + peer1 := "peer1" + a.peerCounts[peer1] = channeldb.ChanCount{ + HasOpenOrClosedChan: true, + } + peer1Access := peerStatusProtected + + // peer2 exists with a pending channel. + peer2 := "peer2" + a.peerCounts[peer2] = channeldb.ChanCount{ + PendingOpenCount: 1, + } + peer2Access := peerStatusTemporary + + // peer3 exists without any channels. + peer3 := "peer3" + a.peerCounts[peer3] = channeldb.ChanCount{} + peer3Access := peerStatusRestricted + + // peer4 exists with a score. + peer4 := "peer4" + peer4Access := peerStatusTemporary + a.peerScores[peer4] = peerSlotStatus{state: peer4Access} + + // peer5 doesn't exist. + peer5 := "peer5" + + // We now assert `hasPeer` returns the correct results. + // + // peer1 should be found with peerStatusProtected. + access, found := a.hasPeer(ctx, peer1) + require.True(t, found) + require.Equal(t, peer1Access, access) + + // peer2 should be found with peerStatusTemporary. + access, found = a.hasPeer(ctx, peer2) + require.True(t, found) + require.Equal(t, peer2Access, access) + + // peer3 should be found with peerStatusRestricted. + access, found = a.hasPeer(ctx, peer3) + require.True(t, found) + require.Equal(t, peer3Access, access) + + // peer4 should be found with peerStatusTemporary. + access, found = a.hasPeer(ctx, peer4) + require.True(t, found) + require.Equal(t, peer4Access, access) + + // peer5 should NOT be found. + access, found = a.hasPeer(ctx, peer5) + require.False(t, found) + require.Equal(t, peerStatusRestricted, access) +}