Skip to content

Commit 3ab4762

Browse files
authored
Merge pull request #10121 from ellemouton/graphPerf6
[5] sqldb+graph/db: add and use new paginate & batch helper
2 parents 7152224 + ae13158 commit 3ab4762

File tree

6 files changed

+818
-149
lines changed

6 files changed

+818
-149
lines changed

graph/db/graph_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,6 +2510,11 @@ func TestFilterKnownChanIDs(t *testing.T) {
25102510
// methods that acquire the cache mutex along with the DB mutex.
25112511
func TestStressTestChannelGraphAPI(t *testing.T) {
25122512
t.Parallel()
2513+
2514+
if testing.Short() {
2515+
t.Skipf("Skipping test in short mode")
2516+
}
2517+
25132518
ctx := context.Background()
25142519

25152520
graph := MakeTestGraph(t)

graph/db/sql_migration_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,10 @@ func assertResultState(t *testing.T, sql *SQLStore, expState dbState) {
11771177
func TestMigrateGraphToSQLRapid(t *testing.T) {
11781178
t.Parallel()
11791179

1180+
if testing.Short() {
1181+
t.Skipf("skipping test in short mode")
1182+
}
1183+
11801184
dbFixture := NewTestDBFixture(t)
11811185

11821186
rapid.Check(t, func(rt *rapid.T) {

graph/db/sql_store.go

Lines changed: 177 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -806,49 +806,18 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
806806
func (s *SQLStore) ForEachNode(ctx context.Context,
807807
cb func(tx NodeRTx) error, reset func()) error {
808808

809-
var lastID int64
810-
811809
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
812-
nodeCB := func(dbID int64, node *models.LightningNode) error {
813-
err := cb(newSQLGraphNodeTx(
814-
db, s.cfg.ChainHash, dbID, node,
815-
))
816-
if err != nil {
817-
return fmt.Errorf("callback failed for "+
818-
"node(id=%d): %w", dbID, err)
819-
}
820-
lastID = dbID
821-
822-
return nil
823-
}
824-
825-
for {
826-
nodes, err := db.ListNodesPaginated(
827-
ctx, sqlc.ListNodesPaginatedParams{
828-
Version: int16(ProtocolV1),
829-
ID: lastID,
830-
Limit: s.cfg.QueryCfg.MaxPageSize,
831-
},
832-
)
833-
if err != nil {
834-
return fmt.Errorf("unable to fetch nodes: %w",
835-
err)
836-
}
837-
838-
if len(nodes) == 0 {
839-
break
840-
}
841-
842-
err = forEachNodeInBatch(
843-
ctx, s.cfg.QueryCfg, db, nodes, nodeCB,
844-
)
845-
if err != nil {
846-
return fmt.Errorf("unable to iterate over "+
847-
"nodes: %w", err)
848-
}
849-
}
810+
return forEachNodePaginated(
811+
ctx, s.cfg.QueryCfg, db,
812+
ProtocolV1,
813+
func(ctx context.Context, dbNodeID int64,
814+
node *models.LightningNode) error {
850815

851-
return nil
816+
return cb(newSQLGraphNodeTx(
817+
db, s.cfg.ChainHash, dbNodeID, node,
818+
))
819+
},
820+
)
852821
}, reset)
853822
}
854823

@@ -1328,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context,
13281297
cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
13291298
*models.ChannelEdgePolicy) error, reset func()) error {
13301299

1331-
handleChannel := func(db SQLQueries, batchData *batchChannelData,
1332-
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
1333-
1334-
node1, node2, err := buildNodeVertices(
1335-
row.Node1Pubkey, row.Node2Pubkey,
1336-
)
1337-
if err != nil {
1338-
return fmt.Errorf("unable to build node vertices: %w",
1339-
err)
1340-
}
1341-
1342-
edge, err := buildEdgeInfoWithBatchData(
1343-
s.cfg.ChainHash, row.GraphChannel, node1, node2,
1344-
batchData,
1345-
)
1346-
if err != nil {
1347-
return fmt.Errorf("unable to build channel info: %w",
1348-
err)
1349-
}
1350-
1351-
dbPol1, dbPol2, err := extractChannelPolicies(row)
1352-
if err != nil {
1353-
return fmt.Errorf("unable to extract channel "+
1354-
"policies: %w", err)
1355-
}
1356-
1357-
p1, p2, err := buildChanPoliciesWithBatchData(
1358-
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
1359-
)
1360-
if err != nil {
1361-
return fmt.Errorf("unable to build channel "+
1362-
"policies: %w", err)
1363-
}
1364-
1365-
err = cb(edge, p1, p2)
1366-
if err != nil {
1367-
return fmt.Errorf("callback failed for channel "+
1368-
"id=%d: %w", edge.ChannelID, err)
1369-
}
1370-
1371-
return nil
1372-
}
1373-
13741300
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
1375-
lastID := int64(-1)
1376-
for {
1377-
//nolint:ll
1378-
rows, err := db.ListChannelsWithPoliciesPaginated(
1379-
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
1380-
Version: int16(ProtocolV1),
1381-
ID: lastID,
1382-
Limit: s.cfg.QueryCfg.MaxPageSize,
1383-
},
1384-
)
1385-
if err != nil {
1386-
return err
1387-
}
1388-
1389-
if len(rows) == 0 {
1390-
break
1391-
}
1392-
1393-
// Collect the channel & policy IDs that we want to
1394-
// do a batch collection for.
1395-
var (
1396-
channelIDs = make([]int64, len(rows))
1397-
policyIDs = make([]int64, 0, len(rows)*2)
1398-
)
1399-
for i, row := range rows {
1400-
channelIDs[i] = row.GraphChannel.ID
1401-
1402-
// Extract policy IDs from the row
1403-
dbPol1, dbPol2, err := extractChannelPolicies(
1404-
row,
1405-
)
1406-
if err != nil {
1407-
return fmt.Errorf("unable to extract "+
1408-
"channel policies: %w", err)
1409-
}
1410-
1411-
if dbPol1 != nil {
1412-
policyIDs = append(policyIDs, dbPol1.ID)
1413-
}
1414-
1415-
if dbPol2 != nil {
1416-
policyIDs = append(policyIDs, dbPol2.ID)
1417-
}
1418-
}
1419-
1420-
batchData, err := batchLoadChannelData(
1421-
ctx, s.cfg.QueryCfg, db, channelIDs,
1422-
policyIDs,
1423-
)
1424-
if err != nil {
1425-
return fmt.Errorf("unable to batch load "+
1426-
"channel data: %w", err)
1427-
}
1428-
1429-
for _, row := range rows {
1430-
err := handleChannel(db, batchData, row)
1431-
if err != nil {
1432-
return err
1433-
}
1434-
1435-
lastID = row.GraphChannel.ID
1436-
}
1437-
}
1438-
1439-
return nil
1301+
return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
14401302
}, reset)
14411303
}
14421304

@@ -5082,3 +4944,169 @@ func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
50824944
},
50834945
)
50844946
}
4947+
4948+
// forEachNodePaginated executes a paginated query to process each node in the
4949+
// graph. It uses the provided SQLQueries interface to fetch nodes in batches
4950+
// and applies the provided processNode function to each node.
4951+
func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
4952+
db SQLQueries, protocol ProtocolVersion,
4953+
processNode func(context.Context, int64,
4954+
*models.LightningNode) error) error {
4955+
4956+
pageQueryFunc := func(ctx context.Context, lastID int64,
4957+
limit int32) ([]sqlc.GraphNode, error) {
4958+
4959+
return db.ListNodesPaginated(
4960+
ctx, sqlc.ListNodesPaginatedParams{
4961+
Version: int16(protocol),
4962+
ID: lastID,
4963+
Limit: limit,
4964+
},
4965+
)
4966+
}
4967+
4968+
extractPageCursor := func(node sqlc.GraphNode) int64 {
4969+
return node.ID
4970+
}
4971+
4972+
collectFunc := func(node sqlc.GraphNode) (int64, error) {
4973+
return node.ID, nil
4974+
}
4975+
4976+
batchQueryFunc := func(ctx context.Context,
4977+
nodeIDs []int64) (*batchNodeData, error) {
4978+
4979+
return batchLoadNodeData(ctx, cfg, db, nodeIDs)
4980+
}
4981+
4982+
processItem := func(ctx context.Context, dbNode sqlc.GraphNode,
4983+
batchData *batchNodeData) error {
4984+
4985+
node, err := buildNodeWithBatchData(&dbNode, batchData)
4986+
if err != nil {
4987+
return fmt.Errorf("unable to build "+
4988+
"node(id=%d): %w", dbNode.ID, err)
4989+
}
4990+
4991+
return processNode(ctx, dbNode.ID, node)
4992+
}
4993+
4994+
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
4995+
ctx, cfg, int64(-1), pageQueryFunc, extractPageCursor,
4996+
collectFunc, batchQueryFunc, processItem,
4997+
)
4998+
}
4999+
5000+
// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph. Pages are fetched with
// ListChannelsWithPoliciesPaginated (hard-coded to ProtocolV1), the channel and
// policy IDs of each page are collected for a single batch load, and the
// processChannel callback is invoked once per channel with its edge info and
// both directional policies.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
	cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
		*models.ChannelEdgePolicy,
		*models.ChannelEdgePolicy) error) error {

	// channelBatchIDs pairs a channel's DB ID with the DB IDs of the
	// (zero, one, or two) policies attached to it, so both kinds of ID
	// can be gathered per page for a single batch load.
	type channelBatchIDs struct {
		channelID int64
		policyIDs []int64
	}

	// pageQueryFunc fetches one page of channel rows (with their
	// policies), resuming after the given cursor ID.
	pageQueryFunc := func(ctx context.Context, lastID int64,
		limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
		error) {

		return db.ListChannelsWithPoliciesPaginated(
			ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
				Version: int16(ProtocolV1),
				ID:      lastID,
				Limit:   limit,
			},
		)
	}

	// The pagination cursor for a row is the channel's DB ID.
	extractPageCursor := func(
		row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {

		return row.GraphChannel.ID
	}

	// collectFunc gathers the channel ID and any attached policy IDs
	// from a single row.
	collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
		channelBatchIDs, error) {

		ids := channelBatchIDs{
			channelID: row.GraphChannel.ID,
		}

		// Extract policy IDs from the row.
		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return ids, err
		}

		if dbPol1 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
		}
		if dbPol2 != nil {
			ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
		}

		return ids, nil
	}

	// batchDataFunc flattens the collected IDs of a whole page and
	// batch-loads the shared channel/policy data in one go.
	batchDataFunc := func(ctx context.Context,
		allIDs []channelBatchIDs) (*batchChannelData, error) {

		// Separate channel IDs from policy IDs.
		var (
			channelIDs = make([]int64, len(allIDs))
			policyIDs  = make([]int64, 0, len(allIDs)*2)
		)

		for i, ids := range allIDs {
			channelIDs[i] = ids.channelID
			policyIDs = append(policyIDs, ids.policyIDs...)
		}

		return batchLoadChannelData(
			ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
		)
	}

	// processItem assembles the edge info and both policies for one row
	// from the batch-loaded data and hands them to the caller's callback.
	processItem := func(ctx context.Context,
		row sqlc.ListChannelsWithPoliciesPaginatedRow,
		batchData *batchChannelData) error {

		node1, node2, err := buildNodeVertices(
			row.Node1Pubkey, row.Node2Pubkey,
		)
		if err != nil {
			return err
		}

		edge, err := buildEdgeInfoWithBatchData(
			cfg.ChainHash, row.GraphChannel, node1, node2,
			batchData,
		)
		if err != nil {
			return fmt.Errorf("unable to build channel info: %w",
				err)
		}

		dbPol1, dbPol2, err := extractChannelPolicies(row)
		if err != nil {
			return err
		}

		p1, p2, err := buildChanPoliciesWithBatchData(
			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
		)
		if err != nil {
			return err
		}

		return processChannel(edge, p1, p2)
	}

	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
		ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
		collectFunc, batchDataFunc, processItem,
	)
}

make/testing_flags.mk

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ ifneq ($(nocache),)
121121
TEST_FLAGS += -test.count=1
122122
endif
123123

124+
# If the short flag is added, then any unit tests marked with "testing.Short()"
125+
# will be skipped.
126+
ifneq ($(short),)
127+
TEST_FLAGS += -short
128+
endif
129+
124130
GOLIST := $(GOCC) list -tags="$(DEV_TAGS)" -deps $(PKG)/... | grep '$(PKG)'| grep -v '/vendor/'
125131

126132
# UNIT_TARGTED is undefined iff a specific package and/or unit test case is

0 commit comments

Comments
 (0)