diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index c26312f0d..4e05d7009 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1452,8 +1452,8 @@ func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32, err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { dbChans, err := db.GetPublicV1ChannelsBySCID( ctx, sqlc.GetPublicV1ChannelsBySCIDParams{ - StartScid: chanIDStart[:], - EndScid: chanIDEnd[:], + StartScid: chanIDStart, + EndScid: chanIDEnd, }, ) if err != nil { @@ -1560,7 +1560,7 @@ func (s *SQLStore) MarkEdgeZombie(chanID uint64, return db.UpsertZombieChannel( ctx, sqlc.UpsertZombieChannelParams{ Version: int16(ProtocolV1), - Scid: chanIDB[:], + Scid: chanIDB, NodeKey1: pubKey1[:], NodeKey2: pubKey2[:], }, @@ -1592,7 +1592,7 @@ func (s *SQLStore) MarkEdgeLive(chanID uint64) error { err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { res, err := db.DeleteZombieChannel( ctx, sqlc.DeleteZombieChannelParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -1644,7 +1644,7 @@ func (s *SQLStore) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte, err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { zombie, err := db.GetZombieChannel( ctx, sqlc.GetZombieChannelParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -1723,7 +1723,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, row, err := db.GetChannelBySCIDWithPolicies( ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -1786,7 +1786,7 @@ func (s *SQLStore) DeleteChannelEdges(strictZombiePruning, markZombie bool, err = db.UpsertZombieChannel( ctx, sqlc.UpsertZombieChannelParams{ Version: int16(ProtocolV1), - Scid: chanIDB[:], + Scid: chanIDB, NodeKey1: nodeKey1[:], NodeKey2: nodeKey2[:], }, @@ -1833,14 +1833,12 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) ( ctx = context.TODO() edge *models.ChannelEdgeInfo policy1, policy2 *models.ChannelEdgePolicy + chanIDB = channelIDToBytes(chanID) ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - var chanIDB [8]byte - byteOrder.PutUint64(chanIDB[:], chanID) - row, err := db.GetChannelBySCIDWithPolicies( ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -1849,7 +1847,7 @@ func (s *SQLStore) FetchChannelEdgesByID(chanID uint64) ( // index. zombie, err := db.GetZombieChannel( ctx, sqlc.GetZombieChannelParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -2033,13 +2031,11 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, return node1LastUpdate, node2LastUpdate, exists, isZombie, nil } + chanIDB := channelIDToBytes(chanID) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - var chanIDB [8]byte - byteOrder.PutUint64(chanIDB[:], chanID) - channel, err := db.GetChannelBySCID( ctx, sqlc.GetChannelBySCIDParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -2047,7 +2043,7 @@ func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool, // Check if it is a zombie channel. isZombie, err = db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -2179,15 +2175,14 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { for _, chanID := range chanIDs { - var chanIDB [8]byte - byteOrder.PutUint64(chanIDB[:], chanID) + chanIDB := channelIDToBytes(chanID) // TODO(elle): potentially optimize this by using // sqlc.slice() once that works for both SQLite and // Postgres. row, err := db.GetChannelBySCIDWithPolicies( ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -2270,8 +2265,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { for _, chanInfo := range chansInfo { channelID := chanInfo.ShortChannelID.ToUint64() - var chanIDB [8]byte - byteOrder.PutUint64(chanIDB[:], channelID) + chanIDB := channelIDToBytes(channelID) // TODO(elle): potentially optimize this by using // sqlc.slice() once that works for both SQLite and @@ -2279,7 +2273,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, _, err := db.GetChannelBySCID( ctx, sqlc.GetChannelBySCIDParams{ Version: int16(ProtocolV1), - Scid: chanIDB[:], + Scid: chanIDB, }, ) if err == nil { @@ -2291,7 +2285,7 @@ func (s *SQLStore) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo) ([]uint64, isZombie, err := db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -2609,18 +2603,16 @@ func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( endShortChanID = aliasmgr.StartingAlias removedChans []*models.ChannelEdgeInfo - ) - var chanIDStart [8]byte - byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) - var chanIDEnd [8]byte - byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) + chanIDStart = channelIDToBytes(startShortChanID.ToUint64()) + chanIDEnd = channelIDToBytes(endShortChanID.ToUint64()) + ) err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { rows, err := db.GetChannelsBySCIDRange( ctx, sqlc.GetChannelsBySCIDRangeParams{ - StartScid: chanIDStart[:], - EndScid: chanIDEnd[:], + StartScid: chanIDStart, + EndScid: chanIDEnd, }, ) if err != nil { @@ -2688,7 +2680,7 @@ func (s *SQLStore) AddEdgeProof(scid lnwire.ShortChannelID, err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { res, err := db.AddV1ChannelProof( ctx, sqlc.AddV1ChannelProofParams{ - Scid: scidBytes[:], + Scid: scidBytes, Node1Signature: proof.NodeSig1Bytes, Node2Signature: proof.NodeSig2Bytes, Bitcoin1Signature: proof.BitcoinSig1Bytes, @@ -2734,7 +2726,7 @@ func (s *SQLStore) PutClosedScid(scid lnwire.ShortChannelID) error { ) return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { - return db.InsertClosedChannel(ctx, chanIDB[:]) + return db.InsertClosedChannel(ctx, chanIDB) }, sqldb.NoOpReset) } @@ -2750,7 +2742,7 @@ func (s *SQLStore) IsClosedScid(scid lnwire.ShortChannelID) (bool, error) { ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { var err error - isClosed, err = db.IsClosedChannel(ctx, chanIDB[:]) + isClosed, err = db.IsClosedChannel(ctx, chanIDB) if err != nil { return fmt.Errorf("unable to fetch closed channel: %w", err) @@ -3077,7 +3069,7 @@ func updateChanEdgePolicy(ctx context.Context, tx SQLQueries, // abort the transaction which would abort the entire batch. dbChan, err := tx.GetChannelAndNodesBySCID( ctx, sqlc.GetChannelAndNodesBySCIDParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -3779,7 +3771,7 @@ func insertChannel(ctx context.Context, db SQLQueries, // batch of transactions. _, err := db.GetChannelBySCID( ctx, sqlc.GetChannelBySCIDParams{ - Scid: chanIDB[:], + Scid: chanIDB, Version: int16(ProtocolV1), }, ) @@ -3808,7 +3800,7 @@ func insertChannel(ctx context.Context, db SQLQueries, createParams := sqlc.CreateChannelParams{ Version: int16(ProtocolV1), - Scid: chanIDB[:], + Scid: chanIDB, NodeID1: node1DBID, NodeID2: node2DBID, Outpoint: edge.ChannelPoint.String(), @@ -4455,9 +4447,9 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy, // channelIDToBytes converts a channel ID (SCID) to a byte array // representation. -func channelIDToBytes(channelID uint64) [8]byte { +func channelIDToBytes(channelID uint64) []byte { var chanIDB [8]byte byteOrder.PutUint64(chanIDB[:], channelID) - return chanIDB + return chanIDB[:] }