Skip to content

Commit f863752

Browse files
IoannisPanagiotasvnickolov
authored andcommitted
Expose weakness in progress logging
1 parent ad3b361 commit f863752

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
import org.agrona.collections.MutableLong;
23+
import org.junit.jupiter.api.Test;
24+
import org.neo4j.gds.TestProgressTracker;
25+
import org.neo4j.gds.api.Graph;
26+
import org.neo4j.gds.beta.generator.RandomGraphGeneratorBuilder;
27+
import org.neo4j.gds.beta.generator.RelationshipDistribution;
28+
import org.neo4j.gds.core.concurrency.Concurrency;
29+
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
30+
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
31+
import org.neo4j.gds.logging.Log;
32+
import org.neo4j.gds.termination.TerminationFlag;
33+
34+
import java.util.List;
35+
import java.util.Optional;
36+
import java.util.concurrent.Executors;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
import static org.mockito.ArgumentMatchers.anyString;
40+
import static org.mockito.Mockito.doAnswer;
41+
import static org.mockito.Mockito.mock;
42+
43+
class Node2VecProgressTrackingTest {
44+
45+
private static final List<Long> NO_SOURCE_NODES = List.of();
46+
47+
private static final Graph graph = new RandomGraphGeneratorBuilder()
48+
.nodeCount(300)
49+
.relationshipDistribution(RelationshipDistribution.RANDOM)
50+
.averageDegree(15)
51+
.seed(19L)
52+
.build()
53+
.generate();
54+
55+
56+
@Test
57+
void iterationLoggingShouldNotHang() {
58+
var concurrency = 8;
59+
var walkParameters = new SamplingWalkParameters(
60+
NO_SOURCE_NODES,
61+
50,
62+
80,
63+
2.0,
64+
0.5,
65+
0.001,
66+
0.75,
67+
1000
68+
);
69+
70+
var trainParameters = new TrainParameters(
71+
0.025,
72+
0.0001,
73+
1,
74+
10,
75+
7,
76+
128,
77+
EmbeddingInitializer.NORMALIZED
78+
);
79+
80+
81+
var parameters = new Node2VecParameters(
82+
walkParameters,
83+
trainParameters,
84+
new Concurrency(concurrency),
85+
Optional.of(1337L)
86+
);
87+
88+
var lazyMock = mock(Log.class);
89+
var iteration1reached100At = new MutableLong();
90+
doAnswer(invocation -> {
91+
var infoMessage = invocation.getArgument(0, String.class);
92+
if (infoMessage.contains("iteration 1 of 1 100%")){
93+
iteration1reached100At.set(System.currentTimeMillis());
94+
}
95+
return null;
96+
}).when(lazyMock).info(anyString());
97+
var finishedIteration1At = new MutableLong();
98+
doAnswer(invocation -> {
99+
var task = invocation.getArgument(2, String.class);
100+
var infoMessage = invocation.getArgument(3, String.class);
101+
if (task.contains("iteration 1 of 1") && infoMessage.contains("Finished")){
102+
finishedIteration1At.set(System.currentTimeMillis());
103+
}
104+
return null;
105+
}).when(lazyMock).info(anyString(),anyString(),anyString(),anyString());
106+
107+
var progressTracker = new TestProgressTracker(
108+
Node2VecTask.create(graph, parameters),
109+
new LoggerForProgressTrackingAdapter(lazyMock),
110+
new Concurrency(concurrency),
111+
EmptyTaskRegistryFactory.INSTANCE
112+
);
113+
114+
try(var ignored = Executors.newFixedThreadPool(concurrency)) {
115+
Node2Vec.create(
116+
graph,
117+
parameters,
118+
progressTracker,
119+
TerminationFlag.RUNNING_TRUE
120+
).compute().embeddings();
121+
}
122+
assertThat((finishedIteration1At.longValue() - iteration1reached100At.longValue())).isLessThanOrEqualTo(700);
123+
124+
}
125+
}

0 commit comments

Comments
 (0)