Commit 90cec5a

Update after cherry-pick

1 parent def51f1 · commit 90cec5a
2 files changed (+1, −212 lines)

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 1 addition & 87 deletions
@@ -26,7 +26,6 @@
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
 import org.neo4j.gds.mem.BitUtil;
-import org.neo4j.gds.ml.core.functions.Sigmoid;
 import org.neo4j.gds.ml.core.tensor.FloatVector;

 import java.util.ArrayList;
@@ -37,8 +36,6 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.LongUnaryOperator;

-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.scale;
 import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;

 public class Node2VecModel {
@@ -58,7 +55,7 @@ public class Node2VecModel {
     private final ProgressTracker progressTracker;
     private final long randomSeed;

-    private static final double EPSILON = 1e-10;
+    static final double EPSILON = 1e-10;

     Node2VecModel(
         LongUnaryOperator toOriginalId,
@@ -192,89 +189,6 @@ private HugeObjectArray<FloatVector> initializeEmbeddings(
         return embeddings;
     }

-    private static final class TrainingTask implements Runnable {
-        private final HugeObjectArray<FloatVector> centerEmbeddings;
-        private final HugeObjectArray<FloatVector> contextEmbeddings;
-
-        private final PositiveSampleProducer positiveSampleProducer;
-        private final NegativeSampleProducer negativeSampleProducer;
-        private final FloatVector centerGradientBuffer;
-        private final FloatVector contextGradientBuffer;
-        private final int negativeSamplingRate;
-        private final float learningRate;
-
-        private final ProgressTracker progressTracker;
-
-        private double lossSum;
-
-        private TrainingTask(
-            HugeObjectArray<FloatVector> centerEmbeddings,
-            HugeObjectArray<FloatVector> contextEmbeddings,
-            PositiveSampleProducer positiveSampleProducer,
-            NegativeSampleProducer negativeSampleProducer,
-            float learningRate,
-            int negativeSamplingRate,
-            int embeddingDimensions,
-            ProgressTracker progressTracker
-        ) {
-            this.centerEmbeddings = centerEmbeddings;
-            this.contextEmbeddings = contextEmbeddings;
-            this.positiveSampleProducer = positiveSampleProducer;
-            this.negativeSampleProducer = negativeSampleProducer;
-            this.learningRate = learningRate;
-            this.negativeSamplingRate = negativeSamplingRate;
-
-            this.centerGradientBuffer = new FloatVector(embeddingDimensions);
-            this.contextGradientBuffer = new FloatVector(embeddingDimensions);
-            this.progressTracker = progressTracker;
-        }
-
-        @Override
-        public void run() {
-            var buffer = new long[2];
-
-            // this corresponds to a stochastic optimizer as the embeddings are updated after each sample
-            while (positiveSampleProducer.next(buffer)) {
-                trainSample(buffer[0], buffer[1], true);
-
-                for (var i = 0; i < negativeSamplingRate; i++) {
-                    trainSample(buffer[0], negativeSampleProducer.next(), false);
-                }
-                progressTracker.logProgress();
-            }
-        }
-
-        private void trainSample(long center, long context, boolean positive) {
-            var centerEmbedding = centerEmbeddings.get(center);
-            var contextEmbedding = contextEmbeddings.get(context);
-
-            // L_pos = -log sigmoid(center * context) ; gradient: -sigmoid (-center * context)
-            // L_neg = -log sigmoid(-center * context) ; gradient: sigmoid (center * context)
-            float affinity = centerEmbedding.innerProduct(contextEmbedding);
-
-            //When |affinity| > 40, positiveSigmoid = 1. Double precision is not enough.
-            //Make sure negativeSigmoid can never be 0 to avoid infinity loss.
-            double positiveSigmoid = Sigmoid.sigmoid(affinity);
-            double negativeSigmoid = 1 - positiveSigmoid;
-
-            lossSum -= positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON);
-
-            float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
-            // we are doing gradient descent, so we go in the negative direction of the gradient here
-            float scaledGradient = -gradient * learningRate;
-
-            scale(contextEmbedding.data(), scaledGradient, centerGradientBuffer.data());
-            scale(centerEmbedding.data(), scaledGradient, contextGradientBuffer.data());
-
-            addInPlace(centerEmbedding.data(), centerGradientBuffer.data());
-            addInPlace(contextEmbedding.data(), contextGradientBuffer.data());
-        }
-
-        double lossSum() {
-            return lossSum;
-        }
-    }
-
     static class FloatConsumer {
         float[] values;
         int index;
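
For reference, the removed TrainingTask implemented the per-sample skip-gram negative-sampling update described in its own comments: minimize -log sigmoid(center · context) for positive pairs and -log sigmoid(-center · context) for negative pairs, updating both embeddings after every sample. The following is a minimal standalone sketch of that update over plain float[] arrays; the inlined sigmoid, dot product, and in-place vector updates are illustrative stand-ins for the library's Sigmoid and FloatVectorOperations helpers, not the GDS API.

// Illustrative sketch only (not GDS library code): the per-sample update that the
// removed TrainingTask performed, restated over plain float[] arrays.
final class SgnsUpdateSketch {

    static final double EPSILON = 1e-10; // keeps Math.log() finite when a sigmoid saturates

    /** Returns the loss of one (center, context) sample and updates both vectors in place. */
    static double trainSample(float[] center, float[] context, boolean positive, float learningRate) {
        // affinity = center . context
        float affinity = 0f;
        for (int i = 0; i < center.length; i++) {
            affinity += center[i] * context[i];
        }

        double positiveSigmoid = 1.0 / (1.0 + Math.exp(-affinity));
        double negativeSigmoid = 1.0 - positiveSigmoid;

        // L_pos = -log sigmoid(affinity) ; L_neg = -log sigmoid(-affinity)
        double loss = positive
            ? -Math.log(positiveSigmoid + EPSILON)
            : -Math.log(negativeSigmoid + EPSILON);

        // dL/d(affinity): -(1 - sigmoid) for positive pairs, +sigmoid for negative pairs;
        // gradient descent steps in the opposite direction of the gradient
        float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
        float step = -gradient * learningRate;

        // update both embeddings from their pre-update values, as the original gradient buffers did
        for (int i = 0; i < center.length; i++) {
            float oldCenter = center[i];
            center[i] += step * context[i];
            context[i] += step * oldCenter;
        }
        return loss;
    }
}

In the removed run() loop, each positive pair drawn from the PositiveSampleProducer was followed by negativeSamplingRate calls with positive = false against samples from the NegativeSampleProducer.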

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecProgressTrackingTest.java

Lines changed: 0 additions & 125 deletions
This file was deleted.
