graph/db: refactor and clean-up

Refactor channelIDToBytes to return a slice instead of an 8 byte array
so that we dont need to use `[:]` everywhere.

Also make sure we are using this helper everywhere.
This commit is contained in:
Elle Mouton
2025-06-26 13:45:26 +02:00
parent 2310756307
commit 4bde8e2d04

View File

@ -1452,8 +1452,8 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbChans, err := db.GetPublicV1ChannelsBySCID(
ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
StartScid: chanIDStart[:],
EndScid: chanIDEnd[:],
StartScid: chanIDStart,
EndScid: chanIDEnd,
},
)
if err != nil {
@ -1560,7 +1560,7 @@ func (s *SQLStore) MarkEdgeZombie(chanID uint64,
return db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
Scid: chanIDB,
NodeKey1: pubKey1[:],
NodeKey2: pubKey2[:],
},
@ -1592,7 +1592,7 @@ func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.DeleteZombieChannel(
ctx, sqlc.DeleteZombieChannelParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -1644,7 +1644,7 @@ func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
zombie, err := db.GetZombieChannel(
ctx, sqlc.GetZombieChannelParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -1723,7 +1723,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -1786,7 +1786,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
err = db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
Scid: chanIDB,
NodeKey1: nodeKey1[:],
NodeKey2: nodeKey2[:],
},
@ -1833,14 +1833,12 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
ctx = context.TODO()
edge *models.ChannelEdgeInfo
policy1, policy2 *models.ChannelEdgePolicy
chanIDB = channelIDToBytes(chanID)
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -1849,7 +1847,7 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
// index.
zombie, err := db.GetZombieChannel(
ctx, sqlc.GetZombieChannelParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -2033,13 +2031,11 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
}
chanIDB := channelIDToBytes(chanID)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
channel, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -2047,7 +2043,7 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
// Check if it is a zombie channel.
isZombie, err = db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -2179,15 +2175,14 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for _, chanID := range chanIDs {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
chanIDB := channelIDToBytes(chanID)
// TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and
// Postgres.
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -2270,8 +2265,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for _, chanInfo := range chansInfo {
channelID := chanInfo.ShortChannelID.ToUint64()
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], channelID)
chanIDB := channelIDToBytes(channelID)
// TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and
@ -2279,7 +2273,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
_, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
Scid: chanIDB,
},
)
if err == nil {
@ -2291,7 +2285,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
isZombie, err := db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -2609,18 +2603,16 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
endShortChanID = aliasmgr.StartingAlias
removedChans []*models.ChannelEdgeInfo
)
var chanIDStart [8]byte
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
var chanIDEnd [8]byte
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
chanIDEnd = channelIDToBytes(endShortChanID.ToUint64())
)
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
rows, err := db.GetChannelsBySCIDRange(
ctx, sqlc.GetChannelsBySCIDRangeParams{
StartScid: chanIDStart[:],
EndScid: chanIDEnd[:],
StartScid: chanIDStart,
EndScid: chanIDEnd,
},
)
if err != nil {
@ -2688,7 +2680,7 @@ func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.AddV1ChannelProof(
ctx, sqlc.AddV1ChannelProofParams{
Scid: scidBytes[:],
Scid: scidBytes,
Node1Signature: proof.NodeSig1Bytes,
Node2Signature: proof.NodeSig2Bytes,
Bitcoin1Signature: proof.BitcoinSig1Bytes,
@ -2734,7 +2726,7 @@ func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
)
return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
return db.InsertClosedChannel(ctx, chanIDB[:])
return db.InsertClosedChannel(ctx, chanIDB)
}, sqldb.NoOpReset)
}
@ -2750,7 +2742,7 @@ func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var err error
isClosed, err = db.IsClosedChannel(ctx, chanIDB[:])
isClosed, err = db.IsClosedChannel(ctx, chanIDB)
if err != nil {
return fmt.Errorf("unable to fetch closed channel: %w",
err)
@ -3077,7 +3069,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
// abort the transaction which would abort the entire batch.
dbChan, err := tx.GetChannelAndNodesBySCID(
ctx, sqlc.GetChannelAndNodesBySCIDParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -3779,7 +3771,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
// batch of transactions.
_, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:],
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
@ -3808,7 +3800,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
createParams := sqlc.CreateChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
Scid: chanIDB,
NodeID1: node1DBID,
NodeID2: node2DBID,
Outpoint: edge.ChannelPoint.String(),
@ -4455,9 +4447,9 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
// channelIDToBytes converts a channel ID (SCID) to a byte array
// representation.
func channelIDToBytes(channelID uint64) [8]byte {
func channelIDToBytes(channelID uint64) []byte {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], channelID)
return chanIDB
return chanIDB[:]
}