Commit 00a152a

IoannisPanagiotas authored and vnickolov committed
Stick progress logging in PositiveSampleProducer
1 parent f863752 commit 00a152a

File tree

6 files changed: +502 -18 lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 6 additions & 5 deletions
@@ -301,7 +301,7 @@ List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskInd
            partition -> {
                var taskId = taskIndex.getAndIncrement();
                var taskRandomSeed = randomSeed + taskId;
-               var positiveSampleProducer = createPositiveSampleProducer(partition, taskRandomSeed);
+               var positiveSampleProducer = createPositiveSampleProducer(partition, taskRandomSeed, progressTracker);
                var negativeSampleProducer = createNegativeSampleProducer(taskRandomSeed);
                return new TrainingTask(
                    centerEmbeddings,
@@ -310,8 +310,7 @@ List<TrainingTask> createTrainingTasks(float learningRate, AtomicInteger taskInd
                    negativeSampleProducer,
                    learningRate,
                    negativeSamplingRate,
-                   embeddingDimension,
-                   progressTracker
+                   embeddingDimension
                );
            }
        );
@@ -326,13 +325,15 @@ NegativeSampleProducer createNegativeSampleProducer(long randomSeed) {

    PositiveSampleProducer createPositiveSampleProducer(
        DegreePartition partition,
-       long randomSeed
+       long randomSeed,
+       ProgressTracker progressTracker
    ) {
        return new PositiveSampleProducer(
            walks.iterator(partition.startNode(), partition.nodeCount()),
            randomWalkProbabilities.positiveSamplingProbabilities(),
            windowSize,
-           randomSeed
+           randomSeed,
+           progressTracker
        );
    }

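With this change the ProgressTracker is threaded into the PositiveSampleProducer, so progress is reported per consumed walk rather than from within TrainingTask. Not shown in this diff: for those per-walk logProgress() calls to add up to 100%, the training progress leaf would have to be sized in walks. A hedged sketch, assuming the usual GDS Tasks.leaf(description, volume) helper; the method and description names are illustrative, not taken from this commit:

// Sketch only: sizing a progress leaf so that one logProgress() per exhausted
// walk adds up to the full volume. Assumes org.neo4j.gds.core.utils.progress.tasks.Tasks.
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;

final class TrainingProgressTaskSketch {
    static Task trainingProgressTask(long numberOfWalks) {
        // one unit of progress per walk consumed by PositiveSampleProducer
        return Tasks.leaf("train", numberOfWalks);
    }
}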
algo/src/main/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducer.java

Lines changed: 13 additions & 1 deletion
@@ -20,6 +20,7 @@
 package org.neo4j.gds.embeddings.node2vec;

 import org.neo4j.gds.collections.ha.HugeDoubleArray;
+import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

 import java.util.Iterator;
 import java.util.SplittableRandom;
@@ -41,18 +42,22 @@ public class PositiveSampleProducer {
    private int currentWindowStart;
    private int currentWindowEnd;
    private final SplittableRandom probabilitySupplier;
+   private final ProgressTracker progressTracker;
+   private boolean attemptedSamplingWalks = false;

    PositiveSampleProducer(
        Iterator<long[]> walks,
        HugeDoubleArray samplingProbabilities,
        int windowSize,
-       long randomSeed
+       long randomSeed,
+       ProgressTracker progressTracker
    ) {
        this.walks = walks;
        this.samplingProbabilities = samplingProbabilities;

        prefixWindowSize = ceilDiv(windowSize - 1, 2);
        postfixWindowSize = (windowSize - 1) / 2;
+       this.progressTracker = progressTracker;

        this.currentWalk = new long[0];
        this.centerWordIndex = -1;
@@ -71,15 +76,22 @@ public boolean next(long[] buffer) {
    }

    private boolean nextWalk() {
+       if (attemptedSamplingWalks) { // the previous walk has just been exhausted
+           progressTracker.logProgress();
+       }
+       attemptedSamplingWalks = true; // on the first call there is no previous walk to log yet
+
        if (!walks.hasNext()) {
            return false;
        }
        long[] walk = walks.next();
+
        int filteredWalkLength = filter(walk);

        while (filteredWalkLength < 2 && walks.hasNext()) {
            walk = walks.next();
            filteredWalkLength = filter(walk);
+
        }

        if (filteredWalkLength >= 2) {
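The guard at the top of nextWalk() above logs the previous walk only when the next one is fetched, skipping the very first call. A minimal, self-contained sketch of the same pattern (the Runnable stands in for progressTracker::logProgress; all names here are illustrative):

// Minimal illustration of the logging guard used in nextWalk() above.
import java.util.Iterator;
import java.util.List;

final class PerWalkProgressSketch {
    private boolean attemptedSamplingWalks = false;

    boolean advance(Iterator<long[]> walks, Runnable logProgress) {
        if (attemptedSamplingWalks) {
            // a walk has just been exhausted, report it now
            logProgress.run();
        }
        // skip logging on the very first call, where no walk has been consumed yet
        attemptedSamplingWalks = true;
        return walks.hasNext();
    }

    public static void main(String[] args) {
        var sketch = new PerWalkProgressSketch();
        var walks = List.of(new long[]{0, 1, 2}, new long[]{2, 3, 4}).iterator();
        var logged = new int[]{0};
        while (sketch.advance(walks, () -> logged[0]++)) {
            walks.next(); // consume the walk
        }
        System.out.println("progress events: " + logged[0]); // 2, one per walk
    }
}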
algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainingTask.java

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.FloatVector;

import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.scale;

final class TrainingTask implements Runnable {
    private final HugeObjectArray<FloatVector> centerEmbeddings;
    private final HugeObjectArray<FloatVector> contextEmbeddings;

    private final PositiveSampleProducer positiveSampleProducer;
    private final NegativeSampleProducer negativeSampleProducer;
    private final FloatVector centerGradientBuffer;
    private final FloatVector contextGradientBuffer;
    private final int negativeSamplingRate;
    private final float learningRate;

    private double lossSum;

    TrainingTask(
        HugeObjectArray<FloatVector> centerEmbeddings,
        HugeObjectArray<FloatVector> contextEmbeddings,
        PositiveSampleProducer positiveSampleProducer,
        NegativeSampleProducer negativeSampleProducer,
        float learningRate,
        int negativeSamplingRate,
        int embeddingDimensions
    ) {
        this.centerEmbeddings = centerEmbeddings;
        this.contextEmbeddings = contextEmbeddings;
        this.positiveSampleProducer = positiveSampleProducer;
        this.negativeSampleProducer = negativeSampleProducer;
        this.learningRate = learningRate;
        this.negativeSamplingRate = negativeSamplingRate;

        this.centerGradientBuffer = new FloatVector(embeddingDimensions);
        this.contextGradientBuffer = new FloatVector(embeddingDimensions);
    }

    @Override
    public void run() {
        var buffer = new long[2];

        // this corresponds to a stochastic optimizer as the embeddings are updated after each sample
        while (positiveSampleProducer.next(buffer)) {
            trainPositiveSample(buffer[0], buffer[1]);
            for (var i = 0; i < negativeSamplingRate; i++) {
                trainNegativeSample(buffer[0], negativeSampleProducer.next());
            }
        }
    }

    void trainPositiveSample(long center, long context) {
        var centerEmbedding = centerEmbeddings.get(center);
        var contextEmbedding = contextEmbeddings.get(context);

        var scaledGradient = computePositiveGradient(centerEmbedding, contextEmbedding);

        updateEmbeddings(
            centerEmbedding,
            contextEmbedding,
            scaledGradient,
            centerGradientBuffer,
            contextGradientBuffer
        );
    }

    void trainNegativeSample(long center, long context) {
        var centerEmbedding = centerEmbeddings.get(center);
        var contextEmbedding = contextEmbeddings.get(context);

        var scaledGradient = computeNegativeGradient(centerEmbedding, contextEmbedding);

        updateEmbeddings(
            centerEmbedding,
            contextEmbedding,
            scaledGradient,
            centerGradientBuffer,
            contextGradientBuffer
        );
    }

    float computePositiveGradient(FloatVector centerEmbedding, FloatVector contextEmbedding) {
        // L_pos = -log sigmoid(center * context) ; gradient: -sigmoid (-center * context)
        // L_neg = -log sigmoid(-center * context) ; gradient: sigmoid (center * context)
        float affinity = centerEmbedding.innerProduct(contextEmbedding);
        //When |affinity| > 40, positiveSigmoid = 1. Double precision is not enough.
        //Make sure negativeSigmoid can never be 0 to avoid infinity loss.
        double positiveSigmoid = Sigmoid.sigmoid(affinity);
        double negativeSigmoid = 1 - positiveSigmoid;

        lossSum -= Math.log(positiveSigmoid + Node2VecModel.EPSILON);

        float gradient = (float) -negativeSigmoid;
        // we are doing gradient descent, so we go in the negative direction of the gradient here
        return -gradient * learningRate;
    }

    float computeNegativeGradient(FloatVector centerEmbedding, FloatVector contextEmbedding) {
        // L_pos = -log sigmoid(center * context) ; gradient: -sigmoid (-center * context)
        // L_neg = -log sigmoid(-center * context) ; gradient: sigmoid (center * context)
        float affinity = centerEmbedding.innerProduct(contextEmbedding);
        //When |affinity| > 40, positiveSigmoid = 1. Double precision is not enough.
        //Make sure negativeSigmoid can never be 0 to avoid infinity loss.
        double positiveSigmoid = Sigmoid.sigmoid(affinity);
        double negativeSigmoid = 1 - positiveSigmoid;

        lossSum -= Math.log(negativeSigmoid + Node2VecModel.EPSILON);

        float gradient = (float) positiveSigmoid;
        // we are doing gradient descent, so we go in the negative direction of the gradient here
        return -gradient * learningRate;
    }

    void updateEmbeddings(
        FloatVector centerEmbedding,
        FloatVector contextEmbedding,
        float scaledGradient,
        FloatVector centerGradientBuffer,
        FloatVector contextGradientBuffer
    ) {
        scale(contextEmbedding.data(), scaledGradient, centerGradientBuffer.data());
        scale(centerEmbedding.data(), scaledGradient, contextGradientBuffer.data());

        addInPlace(centerEmbedding.data(), centerGradientBuffer.data());
        addInPlace(contextEmbedding.data(), contextGradientBuffer.data());
    }

    double lossSum() {
        return lossSum;
    }

}
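For intuition on the gradient scaling in computePositiveGradient and computeNegativeGradient above, here is a small self-contained numeric check. It uses a local sigmoid instead of the GDS Sigmoid class, and the learning rate and affinity values are illustrative; it is a sketch, not code from this commit.

// Self-contained numeric check of the gradient scaling in TrainingTask above.
public final class GradientScalingCheck {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    public static void main(String[] args) {
        float learningRate = 0.025f; // illustrative value
        float affinity = 0.8f;       // centerEmbedding.innerProduct(contextEmbedding)

        double positiveSigmoid = sigmoid(affinity);
        double negativeSigmoid = 1 - positiveSigmoid;

        // positive pair: d/d(affinity) of -log sigmoid(affinity) is -(1 - sigmoid(affinity)),
        // so the descent step is +(1 - sigmoid(affinity)) * learningRate (pulls the pair together)
        float positiveStep = (float) negativeSigmoid * learningRate;

        // negative pair: d/d(affinity) of -log sigmoid(-affinity) is sigmoid(affinity),
        // so the descent step is -sigmoid(affinity) * learningRate (pushes the pair apart)
        float negativeStep = (float) -positiveSigmoid * learningRate;

        System.out.printf("positive step %+f, negative step %+f%n", positiveStep, negativeStep);
    }
}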

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 6 additions & 6 deletions
@@ -240,12 +240,12 @@ void shouldCreateTrainingTasksWithCorrectRandomSeed() {

        assertThat(trainingTasks).hasSize(5);

-       verify(node2VecModel, times(5)).createPositiveSampleProducer(any(), anyLong());
-       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(1L));
-       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(2L));
-       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(3L));
-       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(4L));
-       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(5L));
+       verify(node2VecModel, times(5)).createPositiveSampleProducer(any(), anyLong(), any(ProgressTracker.class));
+       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(1L), any(ProgressTracker.class));
+       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(2L), any(ProgressTracker.class));
+       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(3L), any(ProgressTracker.class));
+       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(4L), any(ProgressTracker.class));
+       verify(node2VecModel, times(1)).createPositiveSampleProducer(any(), eq(5L), any(ProgressTracker.class));

        verify(node2VecModel, times(5)).createNegativeSampleProducer(anyLong());
        verify(node2VecModel, times(1)).createNegativeSampleProducer(1L);
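The only change here is the extra any(ProgressTracker.class) matcher for the new parameter: once matchers are used, every argument of the verified call needs one. A stand-alone Mockito sketch of that pattern, using an illustrative Factory interface rather than the real Node2VecModel:

// Minimal Mockito sketch of verifying a call after a parameter is added.
// The Factory interface is illustrative, not part of the codebase.
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

class VerifyWithAddedParameterExample {

    interface Factory {
        Object create(Object partition, long randomSeed, Object progressTracker);
    }

    void example() {
        Factory factory = mock(Factory.class);
        factory.create(new Object(), 1L, new Object());

        // every verification needs a matcher for the new argument,
        // even when the test does not care about the concrete tracker instance
        verify(factory, times(1)).create(any(), anyLong(), any());
        verify(factory, times(1)).create(any(), eq(1L), any());
    }
}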

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducerTest.java

Lines changed: 13 additions & 6 deletions
@@ -25,6 +25,7 @@
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.MethodSource;
 import org.neo4j.gds.collections.ha.HugeDoubleArray;
+import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

 import java.util.ArrayList;
 import java.util.Collection;
@@ -60,7 +61,8 @@ void doesNotCauseStackOverflow() {
            walks.iterator(0, nbrOfWalks),
            HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
            10,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );

        var counter = 0L;
@@ -88,7 +90,8 @@ void doesNotCauseStackOverflowDueToBadLuck() {
            walks.iterator(0, nbrOfWalks),
            probabilities,
            10,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );
        // does not overflow the stack = passes test

@@ -112,7 +115,8 @@ void doesNotAttemptToFetchOutsideBatch() {
            walks.iterator(0, nbrOfWalks / 2),
            HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
            10,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );

        var counter = 0L;
@@ -137,7 +141,8 @@ void shouldProducePairsWith(
            walks.iterator(0, walks.size()),
            centerNodeProbabilities,
            windowSize,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );
        while (producer.next(buffer)) {
            actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -160,7 +165,8 @@ void shouldProducePairsWithBounds() {
            walks.iterator(0, 2),
            centerNodeProbabilities,
            3,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );
        while (producer.next(buffer)) {
            actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -206,7 +212,8 @@ void shouldRemoveDownsampledWordFromWalk() {
            walks.iterator(0, walks.size()),
            centerNodeProbabilities,
            3,
-           0
+           0,
+           ProgressTracker.NULL_TRACKER
        );

        while (producer.next(buffer)) {
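These tests pass ProgressTracker.NULL_TRACKER because they only exercise the sampling logic. If one wanted to assert on the new per-walk logging instead, a mocked tracker could count logProgress() calls; a hypothetical sketch, where the ProducerFactory interface is illustrative and not part of the codebase:

// Hypothetical variant: count per-walk progress events with a Mockito mock
// instead of NULL_TRACKER. Not part of this commit.
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.node2vec.PositiveSampleProducer;

class CountingProgressSketch {

    // Illustrative factory standing in for the constructor calls in the tests above;
    // it would build the producer with a real walk iterator and sampling probabilities.
    interface ProducerFactory {
        PositiveSampleProducer create(ProgressTracker progressTracker);
    }

    void progressIsLoggedPerWalk(ProducerFactory factory) {
        var progressTracker = mock(ProgressTracker.class);
        var producer = factory.create(progressTracker);

        var buffer = new long[2];
        while (producer.next(buffer)) {
            // drain all samples so that every walk gets exhausted
        }

        // one logProgress() call per consumed walk
        verify(progressTracker, atLeastOnce()).logProgress();
    }
}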
