@@ -806,49 +806,18 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
806806func (s * SQLStore ) ForEachNode (ctx context.Context ,
807807 cb func (tx NodeRTx ) error , reset func ()) error {
808808
809- var lastID int64
810-
811809 return s .db .ExecTx (ctx , sqldb .ReadTxOpt (), func (db SQLQueries ) error {
812- nodeCB := func (dbID int64 , node * models.LightningNode ) error {
813- err := cb (newSQLGraphNodeTx (
814- db , s .cfg .ChainHash , dbID , node ,
815- ))
816- if err != nil {
817- return fmt .Errorf ("callback failed for " +
818- "node(id=%d): %w" , dbID , err )
819- }
820- lastID = dbID
821-
822- return nil
823- }
824-
825- for {
826- nodes , err := db .ListNodesPaginated (
827- ctx , sqlc.ListNodesPaginatedParams {
828- Version : int16 (ProtocolV1 ),
829- ID : lastID ,
830- Limit : s .cfg .QueryCfg .MaxPageSize ,
831- },
832- )
833- if err != nil {
834- return fmt .Errorf ("unable to fetch nodes: %w" ,
835- err )
836- }
837-
838- if len (nodes ) == 0 {
839- break
840- }
841-
842- err = forEachNodeInBatch (
843- ctx , s .cfg .QueryCfg , db , nodes , nodeCB ,
844- )
845- if err != nil {
846- return fmt .Errorf ("unable to iterate over " +
847- "nodes: %w" , err )
848- }
849- }
810+ return forEachNodePaginated (
811+ ctx , s .cfg .QueryCfg , db ,
812+ ProtocolV1 ,
813+ func (ctx context.Context , dbNodeID int64 ,
814+ node * models.LightningNode ) error {
850815
851- return nil
816+ return cb (newSQLGraphNodeTx (
817+ db , s .cfg .ChainHash , dbNodeID , node ,
818+ ))
819+ },
820+ )
852821 }, reset )
853822}
854823
@@ -1328,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context,
13281297 cb func (* models.ChannelEdgeInfo , * models.ChannelEdgePolicy ,
13291298 * models.ChannelEdgePolicy ) error , reset func ()) error {
13301299
1331- handleChannel := func (db SQLQueries , batchData * batchChannelData ,
1332- row sqlc.ListChannelsWithPoliciesPaginatedRow ) error {
1333-
1334- node1 , node2 , err := buildNodeVertices (
1335- row .Node1Pubkey , row .Node2Pubkey ,
1336- )
1337- if err != nil {
1338- return fmt .Errorf ("unable to build node vertices: %w" ,
1339- err )
1340- }
1341-
1342- edge , err := buildEdgeInfoWithBatchData (
1343- s .cfg .ChainHash , row .GraphChannel , node1 , node2 ,
1344- batchData ,
1345- )
1346- if err != nil {
1347- return fmt .Errorf ("unable to build channel info: %w" ,
1348- err )
1349- }
1350-
1351- dbPol1 , dbPol2 , err := extractChannelPolicies (row )
1352- if err != nil {
1353- return fmt .Errorf ("unable to extract channel " +
1354- "policies: %w" , err )
1355- }
1356-
1357- p1 , p2 , err := buildChanPoliciesWithBatchData (
1358- dbPol1 , dbPol2 , edge .ChannelID , node1 , node2 , batchData ,
1359- )
1360- if err != nil {
1361- return fmt .Errorf ("unable to build channel " +
1362- "policies: %w" , err )
1363- }
1364-
1365- err = cb (edge , p1 , p2 )
1366- if err != nil {
1367- return fmt .Errorf ("callback failed for channel " +
1368- "id=%d: %w" , edge .ChannelID , err )
1369- }
1370-
1371- return nil
1372- }
1373-
13741300 return s .db .ExecTx (ctx , sqldb .ReadTxOpt (), func (db SQLQueries ) error {
1375- lastID := int64 (- 1 )
1376- for {
1377- //nolint:ll
1378- rows , err := db .ListChannelsWithPoliciesPaginated (
1379- ctx , sqlc.ListChannelsWithPoliciesPaginatedParams {
1380- Version : int16 (ProtocolV1 ),
1381- ID : lastID ,
1382- Limit : s .cfg .QueryCfg .MaxPageSize ,
1383- },
1384- )
1385- if err != nil {
1386- return err
1387- }
1388-
1389- if len (rows ) == 0 {
1390- break
1391- }
1392-
1393- // Collect the channel & policy IDs that we want to
1394- // do a batch collection for.
1395- var (
1396- channelIDs = make ([]int64 , len (rows ))
1397- policyIDs = make ([]int64 , 0 , len (rows )* 2 )
1398- )
1399- for i , row := range rows {
1400- channelIDs [i ] = row .GraphChannel .ID
1401-
1402- // Extract policy IDs from the row
1403- dbPol1 , dbPol2 , err := extractChannelPolicies (
1404- row ,
1405- )
1406- if err != nil {
1407- return fmt .Errorf ("unable to extract " +
1408- "channel policies: %w" , err )
1409- }
1410-
1411- if dbPol1 != nil {
1412- policyIDs = append (policyIDs , dbPol1 .ID )
1413- }
1414-
1415- if dbPol2 != nil {
1416- policyIDs = append (policyIDs , dbPol2 .ID )
1417- }
1418- }
1419-
1420- batchData , err := batchLoadChannelData (
1421- ctx , s .cfg .QueryCfg , db , channelIDs ,
1422- policyIDs ,
1423- )
1424- if err != nil {
1425- return fmt .Errorf ("unable to batch load " +
1426- "channel data: %w" , err )
1427- }
1428-
1429- for _ , row := range rows {
1430- err := handleChannel (db , batchData , row )
1431- if err != nil {
1432- return err
1433- }
1434-
1435- lastID = row .GraphChannel .ID
1436- }
1437- }
1438-
1439- return nil
1301+ return forEachChannelWithPolicies (ctx , db , s .cfg , cb )
14401302 }, reset )
14411303}
14421304
@@ -5082,3 +4944,169 @@ func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
50824944 },
50834945 )
50844946}
4947+
4948+ // forEachNodePaginated executes a paginated query to process each node in the
4949+ // graph. It uses the provided SQLQueries interface to fetch nodes in batches
4950+ // and applies the provided processNode function to each node.
4951+ func forEachNodePaginated (ctx context.Context , cfg * sqldb.QueryConfig ,
4952+ db SQLQueries , protocol ProtocolVersion ,
4953+ processNode func (context.Context , int64 ,
4954+ * models.LightningNode ) error ) error {
4955+
4956+ pageQueryFunc := func (ctx context.Context , lastID int64 ,
4957+ limit int32 ) ([]sqlc.GraphNode , error ) {
4958+
4959+ return db .ListNodesPaginated (
4960+ ctx , sqlc.ListNodesPaginatedParams {
4961+ Version : int16 (protocol ),
4962+ ID : lastID ,
4963+ Limit : limit ,
4964+ },
4965+ )
4966+ }
4967+
4968+ extractPageCursor := func (node sqlc.GraphNode ) int64 {
4969+ return node .ID
4970+ }
4971+
4972+ collectFunc := func (node sqlc.GraphNode ) (int64 , error ) {
4973+ return node .ID , nil
4974+ }
4975+
4976+ batchQueryFunc := func (ctx context.Context ,
4977+ nodeIDs []int64 ) (* batchNodeData , error ) {
4978+
4979+ return batchLoadNodeData (ctx , cfg , db , nodeIDs )
4980+ }
4981+
4982+ processItem := func (ctx context.Context , dbNode sqlc.GraphNode ,
4983+ batchData * batchNodeData ) error {
4984+
4985+ node , err := buildNodeWithBatchData (& dbNode , batchData )
4986+ if err != nil {
4987+ return fmt .Errorf ("unable to build " +
4988+ "node(id=%d): %w" , dbNode .ID , err )
4989+ }
4990+
4991+ return processNode (ctx , dbNode .ID , node )
4992+ }
4993+
4994+ return sqldb .ExecuteCollectAndBatchWithSharedDataQuery (
4995+ ctx , cfg , int64 (- 1 ), pageQueryFunc , extractPageCursor ,
4996+ collectFunc , batchQueryFunc , processItem ,
4997+ )
4998+ }
4999+
5000+ // forEachChannelWithPolicies executes a paginated query to process each channel
5001+ // with policies in the graph.
5002+ func forEachChannelWithPolicies (ctx context.Context , db SQLQueries ,
5003+ cfg * SQLStoreConfig , processChannel func (* models.ChannelEdgeInfo ,
5004+ * models.ChannelEdgePolicy ,
5005+ * models.ChannelEdgePolicy ) error ) error {
5006+
5007+ type channelBatchIDs struct {
5008+ channelID int64
5009+ policyIDs []int64
5010+ }
5011+
5012+ pageQueryFunc := func (ctx context.Context , lastID int64 ,
5013+ limit int32 ) ([]sqlc.ListChannelsWithPoliciesPaginatedRow ,
5014+ error ) {
5015+
5016+ return db .ListChannelsWithPoliciesPaginated (
5017+ ctx , sqlc.ListChannelsWithPoliciesPaginatedParams {
5018+ Version : int16 (ProtocolV1 ),
5019+ ID : lastID ,
5020+ Limit : limit ,
5021+ },
5022+ )
5023+ }
5024+
5025+ extractPageCursor := func (
5026+ row sqlc.ListChannelsWithPoliciesPaginatedRow ) int64 {
5027+
5028+ return row .GraphChannel .ID
5029+ }
5030+
5031+ collectFunc := func (row sqlc.ListChannelsWithPoliciesPaginatedRow ) (
5032+ channelBatchIDs , error ) {
5033+
5034+ ids := channelBatchIDs {
5035+ channelID : row .GraphChannel .ID ,
5036+ }
5037+
5038+ // Extract policy IDs from the row.
5039+ dbPol1 , dbPol2 , err := extractChannelPolicies (row )
5040+ if err != nil {
5041+ return ids , err
5042+ }
5043+
5044+ if dbPol1 != nil {
5045+ ids .policyIDs = append (ids .policyIDs , dbPol1 .ID )
5046+ }
5047+ if dbPol2 != nil {
5048+ ids .policyIDs = append (ids .policyIDs , dbPol2 .ID )
5049+ }
5050+
5051+ return ids , nil
5052+ }
5053+
5054+ batchDataFunc := func (ctx context.Context ,
5055+ allIDs []channelBatchIDs ) (* batchChannelData , error ) {
5056+
5057+ // Separate channel IDs from policy IDs.
5058+ var (
5059+ channelIDs = make ([]int64 , len (allIDs ))
5060+ policyIDs = make ([]int64 , 0 , len (allIDs )* 2 )
5061+ )
5062+
5063+ for i , ids := range allIDs {
5064+ channelIDs [i ] = ids .channelID
5065+ policyIDs = append (policyIDs , ids .policyIDs ... )
5066+ }
5067+
5068+ return batchLoadChannelData (
5069+ ctx , cfg .QueryCfg , db , channelIDs , policyIDs ,
5070+ )
5071+ }
5072+
5073+ processItem := func (ctx context.Context ,
5074+ row sqlc.ListChannelsWithPoliciesPaginatedRow ,
5075+ batchData * batchChannelData ) error {
5076+
5077+ node1 , node2 , err := buildNodeVertices (
5078+ row .Node1Pubkey , row .Node2Pubkey ,
5079+ )
5080+ if err != nil {
5081+ return err
5082+ }
5083+
5084+ edge , err := buildEdgeInfoWithBatchData (
5085+ cfg .ChainHash , row .GraphChannel , node1 , node2 ,
5086+ batchData ,
5087+ )
5088+ if err != nil {
5089+ return fmt .Errorf ("unable to build channel info: %w" ,
5090+ err )
5091+ }
5092+
5093+ dbPol1 , dbPol2 , err := extractChannelPolicies (row )
5094+ if err != nil {
5095+ return err
5096+ }
5097+
5098+ p1 , p2 , err := buildChanPoliciesWithBatchData (
5099+ dbPol1 , dbPol2 , edge .ChannelID , node1 , node2 , batchData ,
5100+ )
5101+ if err != nil {
5102+ return err
5103+ }
5104+
5105+ return processChannel (edge , p1 , p2 )
5106+ }
5107+
5108+ return sqldb .ExecuteCollectAndBatchWithSharedDataQuery (
5109+ ctx , cfg .QueryCfg , int64 (- 1 ), pageQueryFunc , extractPageCursor ,
5110+ collectFunc , batchDataFunc , processItem ,
5111+ )
5112+ }
0 commit comments