graph/db: return channel DB info from insertChannel

In preparation for the kvdb->migration code, this commit updates
`insertChannel` to return the ID of the newly inserted channel along
with the IDs of the nodes that the channel links to.
This commit is contained in:
Elle Mouton
2025-06-26 13:49:28 +02:00
parent 4bde8e2d04
commit 92849388b8

View File

@@ -570,7 +570,7 @@ func (s *SQLStore) AddChannelEdge(ctx context.Context,
alreadyExists = false alreadyExists = false
}, },
Do: func(tx SQLQueries) error { Do: func(tx SQLQueries) error {
err := insertChannel(ctx, tx, edge) _, err := insertChannel(ctx, tx, edge)
// Silence ErrEdgeAlreadyExist so that the batch can // Silence ErrEdgeAlreadyExist so that the batch can
// succeed, but propagate the error via local state. // succeed, but propagate the error via local state.
@@ -3759,9 +3759,17 @@ func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
return records, nil return records, nil
} }
// dbChanInfo holds the DB level IDs of a channel and the nodes involved in the
// channel.
type dbChanInfo struct {
channelID int64
node1ID int64
node2ID int64
}
// insertChannel inserts a new channel record into the database. // insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries, func insertChannel(ctx context.Context, db SQLQueries,
edge *models.ChannelEdgeInfo) error { edge *models.ChannelEdgeInfo) (*dbChanInfo, error) {
chanIDB := channelIDToBytes(edge.ChannelID) chanIDB := channelIDToBytes(edge.ChannelID)
@@ -3776,21 +3784,21 @@ func insertChannel(ctx context.Context, db SQLQueries,
}, },
) )
if err == nil { if err == nil {
return ErrEdgeAlreadyExist return nil, ErrEdgeAlreadyExist
} else if !errors.Is(err, sql.ErrNoRows) { } else if !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch channel: %w", err) return nil, fmt.Errorf("unable to fetch channel: %w", err)
} }
// Make sure that at least a "shell" entry for each node is present in // Make sure that at least a "shell" entry for each node is present in
// the nodes table. // the nodes table.
node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes) node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
if err != nil { if err != nil {
return fmt.Errorf("unable to create shell node: %w", err) return nil, fmt.Errorf("unable to create shell node: %w", err)
} }
node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes) node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
if err != nil { if err != nil {
return fmt.Errorf("unable to create shell node: %w", err) return nil, fmt.Errorf("unable to create shell node: %w", err)
} }
var capacity sql.NullInt64 var capacity sql.NullInt64
@@ -3821,7 +3829,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
// Insert the new channel record. // Insert the new channel record.
dbChanID, err := db.CreateChannel(ctx, createParams) dbChanID, err := db.CreateChannel(ctx, createParams)
if err != nil { if err != nil {
return err return nil, err
} }
// Insert any channel features. // Insert any channel features.
@@ -3829,7 +3837,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
chanFeatures := lnwire.NewRawFeatureVector() chanFeatures := lnwire.NewRawFeatureVector()
err := chanFeatures.Decode(bytes.NewReader(edge.Features)) err := chanFeatures.Decode(bytes.NewReader(edge.Features))
if err != nil { if err != nil {
return err return nil, err
} }
fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features) fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
@@ -3841,7 +3849,7 @@ func insertChannel(ctx context.Context, db SQLQueries,
}, },
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to insert "+ return nil, fmt.Errorf("unable to insert "+
"channel(%d) feature(%v): %w", dbChanID, "channel(%d) feature(%v): %w", dbChanID,
feature, err) feature, err)
} }
@@ -3851,8 +3859,8 @@ func insertChannel(ctx context.Context, db SQLQueries,
// Finally, insert any extra TLV fields in the channel announcement. // Finally, insert any extra TLV fields in the channel announcement.
extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData) extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
if err != nil { if err != nil {
return fmt.Errorf("unable to marshal extra opaque data: %w", return nil, fmt.Errorf("unable to marshal extra opaque "+
err) "data: %w", err)
} }
for tlvType, value := range extra { for tlvType, value := range extra {
@@ -3864,13 +3872,17 @@ func insertChannel(ctx context.Context, db SQLQueries,
}, },
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to upsert channel(%d) extra "+ return nil, fmt.Errorf("unable to upsert "+
"signed field(%v): %w", edge.ChannelID, "channel(%d) extra signed field(%v): %w",
tlvType, err) edge.ChannelID, tlvType, err)
} }
} }
return nil return &dbChanInfo{
channelID: dbChanID,
node1ID: node1DBID,
node2ID: node2DBID,
}, nil
} }
// maybeCreateShellNode checks if a shell node entry exists for the // maybeCreateShellNode checks if a shell node entry exists for the