mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-12-01 08:29:00 +01:00
graph/db+sqldb: implement HasChannelEdge and ChannelID
And run `TestEdgeInfoUpdates` against our SQL backends.
This commit is contained in:
@@ -812,7 +812,7 @@ func TestEdgeInfoUpdates(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
graph := MakeTestGraph(t)
|
graph := MakeTestGraphNew(t)
|
||||||
|
|
||||||
// We'd like to test the update of edges inserted into the database, so
|
// We'd like to test the update of edges inserted into the database, so
|
||||||
// we create two vertexes to connect.
|
// we create two vertexes to connect.
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ type SQLQueries interface {
|
|||||||
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
|
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
|
||||||
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
|
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
|
||||||
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
|
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
|
||||||
|
GetSCIDByOutpoint(ctx context.Context, arg sqlc.GetSCIDByOutpointParams) ([]byte, error)
|
||||||
DeleteChannel(ctx context.Context, id int64) error
|
DeleteChannel(ctx context.Context, id int64) error
|
||||||
|
|
||||||
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
|
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
|
||||||
@@ -1833,6 +1834,162 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(op *wire.OutPoint) (
|
|||||||
return edge, policy1, policy2, nil
|
return edge, policy1, policy2, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasChannelEdge returns true if the database knows of a channel edge with the
|
||||||
|
// passed channel ID, and false otherwise. If an edge with that ID is found
|
||||||
|
// within the graph, then two time stamps representing the last time the edge
|
||||||
|
// was updated for both directed edges are returned along with the boolean. If
|
||||||
|
// it is not found, then the zombie index is checked and its result is returned
|
||||||
|
// as the second boolean.
|
||||||
|
//
|
||||||
|
// NOTE: part of the V1Store interface.
|
||||||
|
func (s *SQLStore) HasChannelEdge(chanID uint64) (time.Time, time.Time, bool,
|
||||||
|
bool, error) {
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
|
|
||||||
|
var (
|
||||||
|
exists bool
|
||||||
|
isZombie bool
|
||||||
|
node1LastUpdate time.Time
|
||||||
|
node2LastUpdate time.Time
|
||||||
|
)
|
||||||
|
|
||||||
|
// We'll query the cache with the shared lock held to allow multiple
|
||||||
|
// readers to access values in the cache concurrently if they exist.
|
||||||
|
s.cacheMu.RLock()
|
||||||
|
if entry, ok := s.rejectCache.get(chanID); ok {
|
||||||
|
s.cacheMu.RUnlock()
|
||||||
|
node1LastUpdate = time.Unix(entry.upd1Time, 0)
|
||||||
|
node2LastUpdate = time.Unix(entry.upd2Time, 0)
|
||||||
|
exists, isZombie = entry.flags.unpack()
|
||||||
|
|
||||||
|
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
|
||||||
|
}
|
||||||
|
s.cacheMu.RUnlock()
|
||||||
|
|
||||||
|
s.cacheMu.Lock()
|
||||||
|
defer s.cacheMu.Unlock()
|
||||||
|
|
||||||
|
// The item was not found with the shared lock, so we'll acquire the
|
||||||
|
// exclusive lock and check the cache again in case another method added
|
||||||
|
// the entry to the cache while no lock was held.
|
||||||
|
if entry, ok := s.rejectCache.get(chanID); ok {
|
||||||
|
node1LastUpdate = time.Unix(entry.upd1Time, 0)
|
||||||
|
node2LastUpdate = time.Unix(entry.upd2Time, 0)
|
||||||
|
exists, isZombie = entry.flags.unpack()
|
||||||
|
|
||||||
|
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||||
|
var chanIDB [8]byte
|
||||||
|
byteOrder.PutUint64(chanIDB[:], chanID)
|
||||||
|
|
||||||
|
channel, err := db.GetChannelBySCID(
|
||||||
|
ctx, sqlc.GetChannelBySCIDParams{
|
||||||
|
Scid: chanIDB[:],
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
// Check if it is a zombie channel.
|
||||||
|
isZombie, err = db.IsZombieChannel(
|
||||||
|
ctx, sqlc.IsZombieChannelParams{
|
||||||
|
Scid: chanIDB[:],
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not check if channel "+
|
||||||
|
"is zombie: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("unable to fetch channel: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists = true
|
||||||
|
|
||||||
|
policy1, err := db.GetChannelPolicyByChannelAndNode(
|
||||||
|
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
ChannelID: channel.ID,
|
||||||
|
NodeID: channel.NodeID1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return fmt.Errorf("unable to fetch channel policy: %w",
|
||||||
|
err)
|
||||||
|
} else if err == nil {
|
||||||
|
node1LastUpdate = time.Unix(policy1.LastUpdate.Int64, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
policy2, err := db.GetChannelPolicyByChannelAndNode(
|
||||||
|
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
ChannelID: channel.ID,
|
||||||
|
NodeID: channel.NodeID2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return fmt.Errorf("unable to fetch channel policy: %w",
|
||||||
|
err)
|
||||||
|
} else if err == nil {
|
||||||
|
node2LastUpdate = time.Unix(policy2.LastUpdate.Int64, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, sqldb.NoOpReset)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, time.Time{}, false, false,
|
||||||
|
fmt.Errorf("unable to fetch channel: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.rejectCache.insert(chanID, rejectCacheEntry{
|
||||||
|
upd1Time: node1LastUpdate.Unix(),
|
||||||
|
upd2Time: node2LastUpdate.Unix(),
|
||||||
|
flags: packRejectFlags(exists, isZombie),
|
||||||
|
})
|
||||||
|
|
||||||
|
return node1LastUpdate, node2LastUpdate, exists, isZombie, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
|
||||||
|
// passed channel point (outpoint). If the passed channel doesn't exist within
|
||||||
|
// the database, then ErrEdgeNotFound is returned.
|
||||||
|
//
|
||||||
|
// NOTE: part of the V1Store interface.
|
||||||
|
func (s *SQLStore) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
|
||||||
|
var (
|
||||||
|
ctx = context.TODO()
|
||||||
|
channelID uint64
|
||||||
|
)
|
||||||
|
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||||
|
chanID, err := db.GetSCIDByOutpoint(
|
||||||
|
ctx, sqlc.GetSCIDByOutpointParams{
|
||||||
|
Outpoint: chanPoint.String(),
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return ErrEdgeNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("unable to fetch channel ID: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
channelID = byteOrder.Uint64(chanID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, sqldb.NoOpReset)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("unable to fetch channel ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return channelID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// forEachNodeDirectedChannel iterates through all channels of a given
|
// forEachNodeDirectedChannel iterates through all channels of a given
|
||||||
// node, executing the passed callback on the directed edge representing the
|
// node, executing the passed callback on the directed edge representing the
|
||||||
// channel and its incoming policy. If the node is not found, no error is
|
// channel and its incoming policy. If the node is not found, no error is
|
||||||
@@ -3234,6 +3391,46 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
|
|||||||
|
|
||||||
var policy1, policy2 *sqlc.ChannelPolicy
|
var policy1, policy2 *sqlc.ChannelPolicy
|
||||||
switch r := row.(type) {
|
switch r := row.(type) {
|
||||||
|
case sqlc.GetChannelByOutpointWithPoliciesRow:
|
||||||
|
if r.Policy1ID.Valid {
|
||||||
|
policy1 = &sqlc.ChannelPolicy{
|
||||||
|
ID: r.Policy1ID.Int64,
|
||||||
|
Version: r.Policy1Version.Int16,
|
||||||
|
ChannelID: r.Channel.ID,
|
||||||
|
NodeID: r.Policy1NodeID.Int64,
|
||||||
|
Timelock: r.Policy1Timelock.Int32,
|
||||||
|
FeePpm: r.Policy1FeePpm.Int64,
|
||||||
|
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
|
||||||
|
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
|
||||||
|
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
|
||||||
|
LastUpdate: r.Policy1LastUpdate,
|
||||||
|
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
|
||||||
|
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
|
||||||
|
Disabled: r.Policy1Disabled,
|
||||||
|
Signature: r.Policy1Signature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Policy2ID.Valid {
|
||||||
|
policy2 = &sqlc.ChannelPolicy{
|
||||||
|
ID: r.Policy2ID.Int64,
|
||||||
|
Version: r.Policy2Version.Int16,
|
||||||
|
ChannelID: r.Channel.ID,
|
||||||
|
NodeID: r.Policy2NodeID.Int64,
|
||||||
|
Timelock: r.Policy2Timelock.Int32,
|
||||||
|
FeePpm: r.Policy2FeePpm.Int64,
|
||||||
|
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
|
||||||
|
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
|
||||||
|
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
|
||||||
|
LastUpdate: r.Policy2LastUpdate,
|
||||||
|
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
|
||||||
|
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
|
||||||
|
Disabled: r.Policy2Disabled,
|
||||||
|
Signature: r.Policy2Signature,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return policy1, policy2, nil
|
||||||
|
|
||||||
case sqlc.GetChannelBySCIDWithPoliciesRow:
|
case sqlc.GetChannelBySCIDWithPoliciesRow:
|
||||||
if r.Policy1ID.Valid {
|
if r.Policy1ID.Valid {
|
||||||
policy1 = &sqlc.ChannelPolicy{
|
policy1 = &sqlc.ChannelPolicy{
|
||||||
|
|||||||
@@ -1167,6 +1167,23 @@ func (q *Queries) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getSCIDByOutpoint = `-- name: GetSCIDByOutpoint :one
|
||||||
|
SELECT scid from channels
|
||||||
|
WHERE outpoint = $1 AND version = $2
|
||||||
|
`
|
||||||
|
|
||||||
|
type GetSCIDByOutpointParams struct {
|
||||||
|
Outpoint string
|
||||||
|
Version int16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) GetSCIDByOutpoint(ctx context.Context, arg GetSCIDByOutpointParams) ([]byte, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getSCIDByOutpoint, arg.Outpoint, arg.Version)
|
||||||
|
var scid []byte
|
||||||
|
err := row.Scan(&scid)
|
||||||
|
return scid, err
|
||||||
|
}
|
||||||
|
|
||||||
const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many
|
const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many
|
||||||
SELECT sn.node_id, n.pub_key
|
SELECT sn.node_id, n.pub_key
|
||||||
FROM source_nodes sn
|
FROM source_nodes sn
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ type Querier interface {
|
|||||||
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
|
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
|
||||||
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
|
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
|
||||||
GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error)
|
GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error)
|
||||||
|
GetSCIDByOutpoint(ctx context.Context, arg GetSCIDByOutpointParams) ([]byte, error)
|
||||||
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
|
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
|
||||||
GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error)
|
GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (ZombieChannel, error)
|
||||||
HighestSCID(ctx context.Context, version int16) ([]byte, error)
|
HighestSCID(ctx context.Context, version int16) ([]byte, error)
|
||||||
|
|||||||
@@ -206,6 +206,10 @@ SELECT
|
|||||||
FROM channel_extra_types cet
|
FROM channel_extra_types cet
|
||||||
WHERE cet.channel_id = $1;
|
WHERE cet.channel_id = $1;
|
||||||
|
|
||||||
|
-- name: GetSCIDByOutpoint :one
|
||||||
|
SELECT scid from channels
|
||||||
|
WHERE outpoint = $1 AND version = $2;
|
||||||
|
|
||||||
-- name: GetChannelsByPolicyLastUpdateRange :many
|
-- name: GetChannelsByPolicyLastUpdateRange :many
|
||||||
SELECT
|
SELECT
|
||||||
sqlc.embed(c),
|
sqlc.embed(c),
|
||||||
|
|||||||
Reference in New Issue
Block a user