Skip to content

Commit 43d523c

Browse files
Make Node2Vec respect TerminationFlag during Training
Co-authored-by: Ioannis Panagiotas <[email protected]>
1 parent 90cec5a commit 43d523c

File tree

5 files changed

+80
-11
lines changed

5 files changed

+80
-11
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ public Node2VecResult compute() {
130130
maybeRandomSeed,
131131
walks,
132132
probabilitiesBuilder.build(),
133-
progressTracker
133+
progressTracker,
134+
terminationFlag
134135
);
135136

136137
var result = node2VecModel.train();

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2828
import org.neo4j.gds.mem.BitUtil;
2929
import org.neo4j.gds.ml.core.tensor.FloatVector;
30+
import org.neo4j.gds.termination.TerminationFlag;
3031

3132
import java.util.ArrayList;
3233
import java.util.List;
@@ -54,6 +55,7 @@ public class Node2VecModel {
5455
private final RandomWalkProbabilities randomWalkProbabilities;
5556
private final ProgressTracker progressTracker;
5657
private final long randomSeed;
58+
private final TerminationFlag terminationFlag;
5759

5860
static final double EPSILON = 1e-10;
5961

@@ -65,7 +67,8 @@ public class Node2VecModel {
6567
Optional<Long> maybeRandomSeed,
6668
CompressedRandomWalks walks,
6769
RandomWalkProbabilities randomWalkProbabilities,
68-
ProgressTracker progressTracker
70+
ProgressTracker progressTracker,
71+
TerminationFlag terminationFlag
6972
) {
7073
this(
7174
toOriginalId,
@@ -81,7 +84,8 @@ public class Node2VecModel {
8184
maybeRandomSeed,
8285
walks,
8386
randomWalkProbabilities,
84-
progressTracker
87+
progressTracker,
88+
terminationFlag
8589
);
8690
}
8791

@@ -99,7 +103,8 @@ public class Node2VecModel {
99103
Optional<Long> maybeRandomSeed,
100104
CompressedRandomWalks walks,
101105
RandomWalkProbabilities randomWalkProbabilities,
102-
ProgressTracker progressTracker
106+
ProgressTracker progressTracker,
107+
TerminationFlag terminationFlag
103108
) {
104109
this.initialLearningRate = initialLearningRate;
105110
this.minLearningRate = minLearningRate;
@@ -113,6 +118,7 @@ public class Node2VecModel {
113118
this.randomWalkProbabilities = randomWalkProbabilities;
114119
this.progressTracker = progressTracker;
115120
this.randomSeed = maybeRandomSeed.orElseGet(() -> new SplittableRandom().nextLong());
121+
this.terminationFlag = terminationFlag;
116122

117123
var random = new Random();
118124
centerEmbeddings = initializeEmbeddings(toOriginalId, nodeCount, embeddingDimension, random);
@@ -140,6 +146,7 @@ Node2VecResult train() {
140146

141147
RunWithConcurrency.builder()
142148
.concurrency(concurrency)
149+
.terminationFlag(terminationFlag)
143150
.tasks(tasks)
144151
.run();
145152

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecMemoryEstimateDefinitionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void shouldEstimateMemory() {
3939

4040
MemoryEstimationAssert.assertThat(memoryEstimation)
4141
.memoryRange(1000, new Concurrency(1))
42-
.hasSameMinAndMaxEqualTo(7688456L);
42+
.hasSameMinAndMaxEqualTo(7688464L);
4343
}
4444

4545
}

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.neo4j.gds.core.concurrency.Concurrency;
2828
import org.neo4j.gds.core.utils.Intersections;
2929
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
30+
import org.neo4j.gds.termination.TerminatedException;
31+
import org.neo4j.gds.termination.TerminationFlag;
3032

3133
import java.util.Optional;
3234
import java.util.Random;
@@ -35,6 +37,7 @@
3537
import java.util.stream.LongStream;
3638

3739
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3841
import static org.junit.jupiter.api.Assertions.assertEquals;
3942
import static org.mockito.ArgumentMatchers.any;
4043
import static org.mockito.ArgumentMatchers.anyLong;
@@ -83,7 +86,8 @@ void testModel() {
8386
Optional.empty(),
8487
walks,
8588
probabilitiesBuilder.build(),
86-
ProgressTracker.NULL_TRACKER
89+
ProgressTracker.NULL_TRACKER,
90+
TerminationFlag.RUNNING_TRUE
8791
);
8892

8993
var trainResult = node2VecModel.train();
@@ -186,7 +190,8 @@ void randomSeed(int iterations) {
186190
Optional.of(1337L),
187191
walks,
188192
probabilitiesBuilder.build(),
189-
ProgressTracker.NULL_TRACKER
193+
ProgressTracker.NULL_TRACKER,
194+
TerminationFlag.RUNNING_TRUE
190195
);
191196

192197
var otherNode2VecModel = new Node2VecModel(
@@ -197,7 +202,8 @@ void randomSeed(int iterations) {
197202
Optional.of(1337L),
198203
walks,
199204
probabilitiesBuilder.build(),
200-
ProgressTracker.NULL_TRACKER
205+
ProgressTracker.NULL_TRACKER,
206+
TerminationFlag.RUNNING_TRUE
201207
);
202208

203209
var embeddings = node2VecModel.train().embeddings();
@@ -231,7 +237,8 @@ void shouldCreateTrainingTasksWithCorrectRandomSeed() {
231237
Optional.of(1L), // Random Seed
232238
randomWalksMock,
233239
randomWalkProbabilitiesMock,
234-
ProgressTracker.NULL_TRACKER
240+
ProgressTracker.NULL_TRACKER,
241+
TerminationFlag.RUNNING_TRUE
235242
)
236243
);
237244

@@ -286,4 +293,57 @@ private static CompressedRandomWalks generateRandomWalks(
286293

287294
return walks;
288295
}
296+
297+
@Test
298+
void shouldRespectTerminationFlag() {
299+
var random = new Random(42);
300+
int numberOfClusters = 2;
301+
int clusterSize = 5;
302+
int numberOfWalks = 2;
303+
int walkLength = 5;
304+
305+
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
306+
numberOfClusters * clusterSize,
307+
new Concurrency(1),
308+
0.001,
309+
0.75
310+
);
311+
312+
var walks = generateRandomWalks(
313+
probabilitiesBuilder,
314+
numberOfClusters,
315+
clusterSize,
316+
numberOfWalks,
317+
walkLength,
318+
random
319+
);
320+
321+
var trainParameters = new TrainParameters(0.05, 0.0001, 10, 2, 1, 2, EmbeddingInitializer.NORMALIZED);
322+
323+
var terminationFlag = new TerminationFlag() {
324+
private int callCount = 0;
325+
@Override
326+
public boolean running() {
327+
++callCount;
328+
return callCount == 2;
329+
}
330+
};
331+
332+
var node2VecModel = new Node2VecModel(
333+
nodeId -> nodeId,
334+
1000,
335+
trainParameters,
336+
new Concurrency(4),
337+
Optional.of(19L),
338+
walks,
339+
probabilitiesBuilder.build(),
340+
ProgressTracker.NULL_TRACKER,
341+
terminationFlag
342+
);
343+
344+
assertThatExceptionOfType(TerminatedException.class)
345+
.isThrownBy(node2VecModel::train);
346+
347+
}
348+
289349
}

applications/algorithms/node-embeddings/src/test/java/org/neo4j/gds/applications/algorithms/embeddings/NodeEmbeddingAlgorithmsTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.neo4j.gds.extension.GdlGraph;
3434
import org.neo4j.gds.extension.Inject;
3535
import org.neo4j.gds.logging.GdsTestLog;
36+
import org.neo4j.gds.termination.TerminationFlag;
3637

3738
import java.util.Optional;
3839

@@ -73,7 +74,7 @@ void shouldLogProgressForNode2Vec() {
7374
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
7475
.build();
7576
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
76-
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, null);
77+
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, TerminationFlag.RUNNING_TRUE);
7778

7879
var configuration = Node2VecStreamConfigImpl.builder().embeddingDimension(128).build();
7980

@@ -106,7 +107,7 @@ void shouldLogProgressForNode2VecWithRelationshipWeights() {
106107
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
107108
.build();
108109
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
109-
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, null);
110+
var nodeEmbeddingAlgorithms = new NodeEmbeddingAlgorithms(null, progressTrackerCreator, TerminationFlag.RUNNING_TRUE);
110111

111112
var configuration = Node2VecStreamConfigImpl.builder().embeddingDimension(128).build();
112113
nodeEmbeddingAlgorithms.node2Vec(graph, configuration);

0 commit comments

Comments
 (0)