Skip to content

Commit

Permalink
Cleaning most of @valueclass usages from algo
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Aug 20, 2024
1 parent 37553d9 commit 0b37d32
Show file tree
Hide file tree
Showing 54 changed files with 144 additions and 494 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public ApproxMaxKCutResult compute() {

progressTracker.endSubTask();

return ApproxMaxKCutResult.of(candidateSolutions[bestIdx], costs[bestIdx].get());
return new ApproxMaxKCutResult(candidateSolutions[bestIdx], costs[bestIdx].get());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,9 @@
*/
package org.neo4j.gds.approxmaxkcut;

import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeByteArray;

@ValueClass
public interface ApproxMaxKCutResult {
// Value at index `i` is the idx of the community to which node with id `i` belongs.
HugeByteArray candidateSolution();

double cutCost();

static ApproxMaxKCutResult of(
public record ApproxMaxKCutResult(
HugeByteArray candidateSolution,
double cutCost
) {
return ImmutableApproxMaxKCutResult
.builder()
.candidateSolution(candidateSolution)
.cutCost(cutCost)
.build();
}
}
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,9 @@

import java.util.function.LongToDoubleFunction;

public class BetwennessCentralityResult implements CentralityAlgorithmResult {
public record BetwennessCentralityResult(HugeAtomicDoubleArray centralities) implements CentralityAlgorithmResult{

private final HugeAtomicDoubleArray centralities;

BetwennessCentralityResult(HugeAtomicDoubleArray centralities){
this.centralities=centralities;
}
@Override
public NodePropertyValues nodePropertyValues() {
return NodePropertyValuesAdapter.adapt(centralities);
Expand All @@ -43,7 +39,4 @@ public LongToDoubleFunction centralityScoreProvider() {
return centralities::get;
}

public HugeAtomicDoubleArray centralities(){
return centralities;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@
import java.util.function.LongToDoubleFunction;


public class ClosenessCentralityResult implements CentralityAlgorithmResult {
public record ClosenessCentralityResult(HugeDoubleArray centralities) implements CentralityAlgorithmResult {

private final HugeDoubleArray centralities;

ClosenessCentralityResult(HugeDoubleArray centralities) {
this.centralities = centralities;
}

@Override
public NodePropertyValues nodePropertyValues() {
return NodePropertyValuesAdapter.adapt(centralities);
Expand Down
14 changes: 6 additions & 8 deletions algo/src/main/java/org/neo4j/gds/conductance/Conductance.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.hsa.HugeSparseDoubleArray;
Expand Down Expand Up @@ -190,7 +189,7 @@ private RelationshipCounts accumulateCounts(

progressTracker.endSubTask();

return ImmutableRelationshipCounts.of(internalCountsBuilder.build(), externalCountsBuilder.build());
return new RelationshipCounts(internalCountsBuilder.build(), externalCountsBuilder.build());
}

private ConductanceResult computeConductances(
Expand Down Expand Up @@ -244,7 +243,7 @@ private ConductanceResult computeConductances(

progressTracker.endSubTask();

return ConductanceResult.of(
return new ConductanceResult(
conductancesBuilder.build(),
globalConductanceSum.get() / globalValidCommunities.longValue()
);
Expand Down Expand Up @@ -316,11 +315,10 @@ HugeSparseDoubleArray externalCounts() {
}
}

@ValueClass
interface RelationshipCounts {
HugeSparseDoubleArray internalCounts();

HugeSparseDoubleArray externalCounts();
}
record RelationshipCounts(
HugeSparseDoubleArray internalCounts,
HugeSparseDoubleArray externalCounts){}


}
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,9 @@
*/
package org.neo4j.gds.conductance;

import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.hsa.HugeSparseDoubleArray;

@ValueClass
public interface ConductanceResult {
HugeSparseDoubleArray communityConductances();

double globalAverageConductance();

static ConductanceResult of(
public record ConductanceResult(
HugeSparseDoubleArray communityConductances,
double globalAverageConductance
) {
return ImmutableConductanceResult
.builder()
.communityConductances(communityConductances)
.globalAverageConductance(globalAverageConductance)
.build();
}

}
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,9 @@

import java.util.function.LongToDoubleFunction;

public class DegreeCentralityResult implements CentralityAlgorithmResult {
public record DegreeCentralityResult(long nodeCount, DegreeFunction degreeFunction) implements CentralityAlgorithmResult {

static DegreeCentralityResult EMPTY=new DegreeCentralityResult(0, v -> 0);

private final DegreeFunction degreeFunction;
private final long nodeCount;

DegreeCentralityResult(long nodeCount, DegreeFunction degreeFunction){
this.degreeFunction=degreeFunction;
this.nodeCount=nodeCount;
}

public DegreeFunction degreeFunction(){
return degreeFunction;
}
static DegreeCentralityResult EMPTY = new DegreeCentralityResult(0, v -> 0);

@Override
public NodePropertyValues nodePropertyValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,5 @@

import org.neo4j.gds.collections.ha.HugeObjectArray;

public class FastRPResult {
private final HugeObjectArray<float[]> embeddings;
public record FastRPResult(HugeObjectArray<float[]> embeddings){}

public FastRPResult(HugeObjectArray<float[]> embeddings) {
this.embeddings = embeddings;
}

public HugeObjectArray<float[]> embeddings() {
return embeddings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.GraphSageEmbeddingsGenerator;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
Expand Down Expand Up @@ -86,6 +86,6 @@ public GraphSageResult compute() {
graph,
features
);
return GraphSageResult.of(embeddings);
return new GraphSageResult(embeddings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@
*/
package org.neo4j.gds.embeddings.graphsage.algo;

import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.embeddings.graphsage.Layer;

@ValueClass
public interface GraphSageModel {

Layer[] layers();

GraphSageTrainConfig config();
}
public record GraphSageModel(
Layer[] layers,
GraphSageTrainConfig config){}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@
*/
package org.neo4j.gds.embeddings.graphsage.algo;

import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeObjectArray;

@ValueClass
public interface GraphSageResult {
HugeObjectArray<double[]> embeddings();

static GraphSageResult of(HugeObjectArray<double[]> embeddings) {
return ImmutableGraphSageResult.of(embeddings);
}
public record GraphSageResult(
HugeObjectArray<double[]> embeddings)
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.BitSetIterator;
import org.apache.commons.math3.primes.Primes;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;

import java.util.SplittableRandom;
Expand Down Expand Up @@ -69,19 +68,13 @@ static void hashArgMin(BitSet bitSet, int[] hashes, HashGNN.MinAndArgmin result)
result.argMin = argMin;
}

@ValueClass
interface HashTriple {
record HashTriple(int a,int b,int c) {

/*
The values a, b and c represent parameters of the hash function: h(x) = x * a + b mod c,
where 0 < a, b < c and c is a prime number.
*/

int a();

int b();

int c();

static HashTriple generate(SplittableRandom rng) {
int c = Primes.nextPrime(rng.nextInt(1, Integer.MAX_VALUE));
Expand All @@ -91,7 +84,7 @@ static HashTriple generate(SplittableRandom rng) {
static HashTriple generate(SplittableRandom rng, int c) {
int a = rng.nextInt(1, c);
int b = rng.nextInt(1, c);
return ImmutableHashTriple.of(a, b, c);
return new HashTriple(a, b, c);
}

static int[] computeHashesFromTriple(int embeddingDimension, HashTriple hashTriple) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,4 @@

import org.neo4j.gds.api.properties.nodes.NodePropertyValues;

public class HashGNNResult {
private final NodePropertyValues embeddings;

public HashGNNResult(NodePropertyValues embeddings) {
this.embeddings = embeddings;
}

public NodePropertyValues embeddings() {
return embeddings;
}
}
public record HashGNNResult(NodePropertyValues embeddings) {}
19 changes: 8 additions & 11 deletions algo/src/main/java/org/neo4j/gds/embeddings/hashgnn/HashTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


import org.apache.commons.math3.primes.Primes;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
Expand All @@ -35,7 +34,7 @@

import static org.neo4j.gds.embeddings.hashgnn.HashGNNCompanion.HashTriple.computeHashesFromTriple;

class HashTask implements Runnable {
final class HashTask implements Runnable {
private static final double MAX_FINAL_INFLUENCE = 1e4;
private static final int PRIME_LOWER_BOUND = 50_000;

Expand All @@ -48,7 +47,7 @@ class HashTask implements Runnable {
private List<int[]> preAggregationHashes;
private final ProgressTracker progressTracker;

HashTask(
private HashTask(
int embeddingDimension,
double scaledNeighborInfluence,
int numberOfRelationshipTypes,
Expand Down Expand Up @@ -95,13 +94,11 @@ public static List<Hashes> compute(
return hashTasks.stream().map(HashTask::hashes).collect(Collectors.toList());
}

@ValueClass
interface Hashes {
int[] neighborsAggregationHashes();

int[] selfAggregationHashes();

List<int[]> preAggregationHashes();
record Hashes(
int[] neighborsAggregationHashes,
int[] selfAggregationHashes,
List<int[]> preAggregationHashes){

static long memoryEstimation(int ambientDimension, int numRelTypes) {
long neighborAggregation = Estimate.sizeOfIntArrayList(ambientDimension);
Expand Down Expand Up @@ -148,7 +145,7 @@ public void run() {
progressTracker.logSteps(1);
}

Hashes hashes() {
return ImmutableHashes.of(neighborsAggregationHashes, selfAggregationHashes, preAggregationHashes);
private Hashes hashes() {
return new Hashes(neighborsAggregationHashes, selfAggregationHashes, preAggregationHashes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ Node2VecResult train() {
}
progressTracker.endSubTask();

return ImmutableNode2VecResult.of(centerEmbeddings, lossPerIteration);
return new Node2VecResult(centerEmbeddings, lossPerIteration);
}

private HugeObjectArray<FloatVector> initializeEmbeddings(LongUnaryOperator toOriginalNodeId, long nodeCount, int embeddingDimensions, Random random) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,10 @@
*/
package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.core.tensor.FloatVector;

import java.util.List;

@ValueClass
public interface Node2VecResult {
HugeObjectArray<FloatVector> embeddings();

List<Double> lossPerIteration();
}
public record Node2VecResult(HugeObjectArray<FloatVector> embeddings,List<Double> lossPerIteration)
{ }
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,7 @@

import java.util.function.LongToDoubleFunction;

public class HarmonicResult implements CentralityAlgorithmResult {

private final HugeAtomicDoubleArray centralities;

HarmonicResult(HugeAtomicDoubleArray centralities) {
this.centralities = centralities;
}
public record HarmonicResult(HugeAtomicDoubleArray centralities) implements CentralityAlgorithmResult {

@Override
public NodePropertyValues nodePropertyValues() {
Expand Down
Loading

0 comments on commit 0b37d32

Please sign in to comment.