graph/db: refactor and clean-up

Refactor channelIDToBytes to return a slice instead of an 8 byte array
so that we dont need to use `[:]` everywhere.

Also make sure we are using this helper everywhere.
This commit is contained in:
Elle Mouton
2025-06-26 13:45:26 +02:00
parent 2310756307
commit 4bde8e2d04

View File

@ -1452,8 +1452,8 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbChans, err := db.GetPublicV1ChannelsBySCID( dbChans, err := db.GetPublicV1ChannelsBySCID(
ctx, sqlc.GetPublicV1ChannelsBySCIDParams{ ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
StartScid: chanIDStart[:], StartScid: chanIDStart,
EndScid: chanIDEnd[:], EndScid: chanIDEnd,
}, },
) )
if err != nil { if err != nil {
@ -1560,7 +1560,7 @@ func (s *SQLStore) MarkEdgeZombie(chanID uint64,
return db.UpsertZombieChannel( return db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{ ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
Scid: chanIDB[:], Scid: chanIDB,
NodeKey1: pubKey1[:], NodeKey1: pubKey1[:],
NodeKey2: pubKey2[:], NodeKey2: pubKey2[:],
}, },
@ -1592,7 +1592,7 @@ func (s *SQLStore) MarkEdgeLive(chanID uint64) error {
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.DeleteZombieChannel( res, err := db.DeleteZombieChannel(
ctx, sqlc.DeleteZombieChannelParams{ ctx, sqlc.DeleteZombieChannelParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -1644,7 +1644,7 @@ func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
zombie, err := db.GetZombieChannel( zombie, err := db.GetZombieChannel(
ctx, sqlc.GetZombieChannelParams{ ctx, sqlc.GetZombieChannelParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -1723,7 +1723,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
row, err := db.GetChannelBySCIDWithPolicies( row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -1786,7 +1786,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool,
err = db.UpsertZombieChannel( err = db.UpsertZombieChannel(
ctx, sqlc.UpsertZombieChannelParams{ ctx, sqlc.UpsertZombieChannelParams{
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
Scid: chanIDB[:], Scid: chanIDB,
NodeKey1: nodeKey1[:], NodeKey1: nodeKey1[:],
NodeKey2: nodeKey2[:], NodeKey2: nodeKey2[:],
}, },
@ -1833,14 +1833,12 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
ctx = context.TODO() ctx = context.TODO()
edge *models.ChannelEdgeInfo edge *models.ChannelEdgeInfo
policy1, policy2 *models.ChannelEdgePolicy policy1, policy2 *models.ChannelEdgePolicy
chanIDB = channelIDToBytes(chanID)
) )
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
row, err := db.GetChannelBySCIDWithPolicies( row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -1849,7 +1847,7 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) (
// index. // index.
zombie, err := db.GetZombieChannel( zombie, err := db.GetZombieChannel(
ctx, sqlc.GetZombieChannelParams{ ctx, sqlc.GetZombieChannelParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -2033,13 +2031,11 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
} }
chanIDB := channelIDToBytes(chanID)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], chanID)
channel, err := db.GetChannelBySCID( channel, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{ ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -2047,7 +2043,7 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
// Check if it is a zombie channel. // Check if it is a zombie channel.
isZombie, err = db.IsZombieChannel( isZombie, err = db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{ ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -2179,15 +2175,14 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
) )
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for _, chanID := range chanIDs { for _, chanID := range chanIDs {
var chanIDB [8]byte chanIDB := channelIDToBytes(chanID)
byteOrder.PutUint64(chanIDB[:], chanID)
// TODO(elle): potentially optimize this by using // TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and // sqlc.slice() once that works for both SQLite and
// Postgres. // Postgres.
row, err := db.GetChannelBySCIDWithPolicies( row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -2270,8 +2265,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for _, chanInfo := range chansInfo { for _, chanInfo := range chansInfo {
channelID := chanInfo.ShortChannelID.ToUint64() channelID := chanInfo.ShortChannelID.ToUint64()
var chanIDB [8]byte chanIDB := channelIDToBytes(channelID)
byteOrder.PutUint64(chanIDB[:], channelID)
// TODO(elle): potentially optimize this by using // TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and // sqlc.slice() once that works for both SQLite and
@ -2279,7 +2273,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
_, err := db.GetChannelBySCID( _, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{ ctx, sqlc.GetChannelBySCIDParams{
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
Scid: chanIDB[:], Scid: chanIDB,
}, },
) )
if err == nil { if err == nil {
@ -2291,7 +2285,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64,
isZombie, err := db.IsZombieChannel( isZombie, err := db.IsZombieChannel(
ctx, sqlc.IsZombieChannelParams{ ctx, sqlc.IsZombieChannelParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -2609,18 +2603,16 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
endShortChanID = aliasmgr.StartingAlias endShortChanID = aliasmgr.StartingAlias
removedChans []*models.ChannelEdgeInfo removedChans []*models.ChannelEdgeInfo
)
var chanIDStart [8]byte chanIDStart = channelIDToBytes(startShortChanID.ToUint64())
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) chanIDEnd = channelIDToBytes(endShortChanID.ToUint64())
var chanIDEnd [8]byte )
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
rows, err := db.GetChannelsBySCIDRange( rows, err := db.GetChannelsBySCIDRange(
ctx, sqlc.GetChannelsBySCIDRangeParams{ ctx, sqlc.GetChannelsBySCIDRangeParams{
StartScid: chanIDStart[:], StartScid: chanIDStart,
EndScid: chanIDEnd[:], EndScid: chanIDEnd,
}, },
) )
if err != nil { if err != nil {
@ -2688,7 +2680,7 @@ func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID,
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.AddV1ChannelProof( res, err := db.AddV1ChannelProof(
ctx, sqlc.AddV1ChannelProofParams{ ctx, sqlc.AddV1ChannelProofParams{
Scid: scidBytes[:], Scid: scidBytes,
Node1Signature: proof.NodeSig1Bytes, Node1Signature: proof.NodeSig1Bytes,
Node2Signature: proof.NodeSig2Bytes, Node2Signature: proof.NodeSig2Bytes,
Bitcoin1Signature: proof.BitcoinSig1Bytes, Bitcoin1Signature: proof.BitcoinSig1Bytes,
@ -2734,7 +2726,7 @@ func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error {
) )
return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
return db.InsertClosedChannel(ctx, chanIDB[:]) return db.InsertClosedChannel(ctx, chanIDB)
}, sqldb.NoOpReset) }, sqldb.NoOpReset)
} }
@ -2750,7 +2742,7 @@ func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) {
) )
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var err error var err error
isClosed, err = db.IsClosedChannel(ctx, chanIDB[:]) isClosed, err = db.IsClosedChannel(ctx, chanIDB)
if err != nil { if err != nil {
return fmt.Errorf("unable to fetch closed channel: %w", return fmt.Errorf("unable to fetch closed channel: %w",
err) err)
@ -3077,7 +3069,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries,
// abort the transaction which would abort the entire batch. // abort the transaction which would abort the entire batch.
dbChan, err := tx.GetChannelAndNodesBySCID( dbChan, err := tx.GetChannelAndNodesBySCID(
ctx, sqlc.GetChannelAndNodesBySCIDParams{ ctx, sqlc.GetChannelAndNodesBySCIDParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -3779,7 +3771,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
// batch of transactions. // batch of transactions.
_, err := db.GetChannelBySCID( _, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{ ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:], Scid: chanIDB,
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
}, },
) )
@ -3808,7 +3800,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
createParams := sqlc.CreateChannelParams{ createParams := sqlc.CreateChannelParams{
Version: int16(ProtocolV1), Version: int16(ProtocolV1),
Scid: chanIDB[:], Scid: chanIDB,
NodeID1: node1DBID, NodeID1: node1DBID,
NodeID2: node2DBID, NodeID2: node2DBID,
Outpoint: edge.ChannelPoint.String(), Outpoint: edge.ChannelPoint.String(),
@ -4455,9 +4447,9 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
// channelIDToBytes converts a channel ID (SCID) to a byte array // channelIDToBytes converts a channel ID (SCID) to a byte array
// representation. // representation.
func channelIDToBytes(channelID uint64) [8]byte { func channelIDToBytes(channelID uint64) []byte {
var chanIDB [8]byte var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], channelID) byteOrder.PutUint64(chanIDB[:], channelID)
return chanIDB return chanIDB[:]
} }