@@ -26,7 +26,6 @@
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
 import org.neo4j.gds.mem.BitUtil;
-import org.neo4j.gds.ml.core.functions.Sigmoid;
 import org.neo4j.gds.ml.core.tensor.FloatVector;
 
 import java.util.ArrayList;
@@ -37,8 +36,6 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.LongUnaryOperator;
 
-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.scale;
 import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
 
 public class Node2VecModel {
@@ -58,7 +55,7 @@ public class Node2VecModel {
     private final ProgressTracker progressTracker;
     private final long randomSeed;
 
-    private static final double EPSILON = 1e-10;
+    static final double EPSILON = 1e-10;
 
     Node2VecModel(
         LongUnaryOperator toOriginalId,
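Side note on the EPSILON constant touched above (only its visibility changes, from private to package-private; the value stays 1e-10): as the comments in the removed code below explain, the sigmoid saturates to exactly 1.0 in double precision once the affinity exceeds roughly 40, so without the additive epsilon the loss term log(1 - sigmoid) would be -Infinity. A quick, self-contained illustration using plain Math calls (not GDS API):

```java
// exp(-41) ~ 1.6e-18 is below the ulp of 1.0 (~2.2e-16),
// so 1/(1 + exp(-41)) rounds to exactly 1.0 in double precision.
double affinity = 41.0;
double positiveSigmoid = 1.0 / (1.0 + Math.exp(-affinity)); // exactly 1.0
double negativeSigmoid = 1.0 - positiveSigmoid;             // exactly 0.0

System.out.println(Math.log(negativeSigmoid));         // -Infinity
System.out.println(Math.log(negativeSigmoid + 1e-10)); // about -23.03, finite
```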
@@ -192,89 +189,6 @@ private HugeObjectArray<FloatVector> initializeEmbeddings(
         return embeddings;
     }
 
-    private static final class TrainingTask implements Runnable {
-        private final HugeObjectArray<FloatVector> centerEmbeddings;
-        private final HugeObjectArray<FloatVector> contextEmbeddings;
-
-        private final PositiveSampleProducer positiveSampleProducer;
-        private final NegativeSampleProducer negativeSampleProducer;
-        private final FloatVector centerGradientBuffer;
-        private final FloatVector contextGradientBuffer;
-        private final int negativeSamplingRate;
-        private final float learningRate;
-
-        private final ProgressTracker progressTracker;
-
-        private double lossSum;
-
-        private TrainingTask(
-            HugeObjectArray<FloatVector> centerEmbeddings,
-            HugeObjectArray<FloatVector> contextEmbeddings,
-            PositiveSampleProducer positiveSampleProducer,
-            NegativeSampleProducer negativeSampleProducer,
-            float learningRate,
-            int negativeSamplingRate,
-            int embeddingDimensions,
-            ProgressTracker progressTracker
-        ) {
-            this.centerEmbeddings = centerEmbeddings;
-            this.contextEmbeddings = contextEmbeddings;
-            this.positiveSampleProducer = positiveSampleProducer;
-            this.negativeSampleProducer = negativeSampleProducer;
-            this.learningRate = learningRate;
-            this.negativeSamplingRate = negativeSamplingRate;
-
-            this.centerGradientBuffer = new FloatVector(embeddingDimensions);
-            this.contextGradientBuffer = new FloatVector(embeddingDimensions);
-            this.progressTracker = progressTracker;
-        }
-
-        @Override
-        public void run() {
-            var buffer = new long[2];
-
-            // this corresponds to a stochastic optimizer as the embeddings are updated after each sample
-            while (positiveSampleProducer.next(buffer)) {
-                trainSample(buffer[0], buffer[1], true);
-
-                for (var i = 0; i < negativeSamplingRate; i++) {
-                    trainSample(buffer[0], negativeSampleProducer.next(), false);
-                }
-                progressTracker.logProgress();
-            }
-        }
-
-        private void trainSample(long center, long context, boolean positive) {
-            var centerEmbedding = centerEmbeddings.get(center);
-            var contextEmbedding = contextEmbeddings.get(context);
-
-            // L_pos = -log sigmoid(center * context) ; gradient: -sigmoid(-center * context)
-            // L_neg = -log sigmoid(-center * context) ; gradient: sigmoid(center * context)
-            float affinity = centerEmbedding.innerProduct(contextEmbedding);
-
-            // When |affinity| > 40, positiveSigmoid == 1. Double precision is not enough.
-            // Make sure negativeSigmoid can never be 0 to avoid an infinite loss.
-            double positiveSigmoid = Sigmoid.sigmoid(affinity);
-            double negativeSigmoid = 1 - positiveSigmoid;
-
-            lossSum -= positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON);
-
-            float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
-            // we are doing gradient descent, so we go in the negative direction of the gradient here
-            float scaledGradient = -gradient * learningRate;
-
-            scale(contextEmbedding.data(), scaledGradient, centerGradientBuffer.data());
-            scale(centerEmbedding.data(), scaledGradient, contextGradientBuffer.data());
-
-            addInPlace(centerEmbedding.data(), centerGradientBuffer.data());
-            addInPlace(contextEmbedding.data(), contextGradientBuffer.data());
-        }
-
-        double lossSum() {
-            return lossSum;
-        }
-    }
-
     static class FloatConsumer {
         float[] values;
         int index;
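For reviewers who want the removed math in one place: the deleted TrainingTask applied one stochastic gradient step per (center, context) sample, updating both embeddings in place. Below is a minimal, self-contained sketch of that step, with plain float[] standing in for FloatVector and the sigmoid inlined instead of going through Sigmoid.sigmoid; the class and method names are illustrative, not part of the GDS codebase.

```java
// Illustrative sketch (not GDS API): one SGD step of skip-gram with negative
// sampling, mirroring the removed TrainingTask#trainSample.
final class SgnsStepSketch {
    static final double EPSILON = 1e-10;

    /** Updates both embeddings in place and returns this sample's loss contribution. */
    static double trainSample(float[] center, float[] context, boolean positive, float learningRate) {
        // affinity = <center, context>
        float affinity = 0f;
        for (int i = 0; i < center.length; i++) {
            affinity += center[i] * context[i];
        }

        double positiveSigmoid = 1.0 / (1.0 + Math.exp(-affinity));
        double negativeSigmoid = 1.0 - positiveSigmoid;

        // L_pos = -log sigmoid(affinity), L_neg = -log sigmoid(-affinity);
        // EPSILON keeps the log away from -Infinity when the sigmoid saturates.
        double loss = positive
            ? -Math.log(positiveSigmoid + EPSILON)
            : -Math.log(negativeSigmoid + EPSILON);

        // dL/d(affinity): -sigmoid(-affinity) for positives, sigmoid(affinity) for negatives
        float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
        // gradient descent: step against the gradient
        float step = -gradient * learningRate;

        // Update both sides from the pre-update value of the other side,
        // matching the original's use of separate gradient buffers.
        for (int i = 0; i < center.length; i++) {
            float centerBefore = center[i];
            center[i] += step * context[i];
            context[i] += step * centerBefore;
        }
        return loss;
    }
}
```

In the removed run() loop, this step ran once with positive = true for each sampled pair, then negativeSamplingRate times with positive = false against randomly drawn negative contexts.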