Skip to content

Commit

Permalink
Merge pull request #8137 from orazve/kge-cont-bench
Browse files Browse the repository at this point in the history
KGE continuous benchmark, scorers performance improvements
  • Loading branch information
brs96 authored Sep 18, 2023
2 parents a7f2494 + eb74298 commit c0330a9
Show file tree
Hide file tree
Showing 14 changed files with 421 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ static PropertyProducer<float[][]> randomEmbedding(String propertyName, int embe
return new RandomEmbeddingProducer(propertyName, embeddingSize, min, max);
}

static PropertyProducer<double[][]> randomEmbeddingDouble(String propertyName, int embeddingSize, double min, double max) {
return new RandomDoubleEmbeddingProducer(propertyName, embeddingSize, min, max);
}

static PropertyProducer<long[]> nodeIdAsLong(String propertyName) {
return new NodeIdProducer(propertyName);
}
Expand Down Expand Up @@ -276,6 +280,69 @@ public String toString() {
}
}

class RandomDoubleEmbeddingProducer implements PropertyProducer<double[][]> {
private final String propertyName;
private final int embeddingSize;
private final double min;
private final double max;

public RandomDoubleEmbeddingProducer(String propertyName, int embeddingSize, double min, double max) {
this.propertyName = propertyName;
this.embeddingSize = embeddingSize;
this.min = min;
this.max = max;

if (max <= min) {
throw new IllegalArgumentException("Max value must be greater than min value");
}
}

@Override
public String getPropertyName() {
return propertyName;
}

@Override
public ValueType propertyType() {
return ValueType.DOUBLE_ARRAY;
}

@Override
public void setProperty(long nodeId, double[][] embeddings, int index, Random random) {
var nodeEmbeddings = new double[embeddingSize];
for (int i = 0; i < embeddingSize; i++) {
nodeEmbeddings[i] = min + (random.nextDouble() * (max - min));
}
embeddings[index] = nodeEmbeddings;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RandomEmbeddingProducer random = (RandomEmbeddingProducer) o;
return random.embeddingSize == embeddingSize &&
Double.compare(random.min, min) == 0 &&
Double.compare(random.max, max) == 0 &&
Objects.equals(propertyName, random.propertyName);
}

@Override
public int hashCode() {
return Objects.hash(propertyName, embeddingSize, min, max);
}

@Override
public String toString() {
return "RandomDoubleProducer{" +
"propertyName='" + propertyName + '\'' +
", embeddingSize=" + embeddingSize +
", min=" + min +
", max=" + max +
'}';
}
}

class NodeIdProducer implements PropertyProducer<long[]> {
private final String propertyName;

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.kge.scorers.LinkScorer;
import org.neo4j.gds.ml.kge.scorers.LinkScorerFactory;
import org.neo4j.gds.similarity.nodesim.TopKMap;
import org.neo4j.gds.utils.AutoCloseableThreadLocal;
import org.neo4j.gds.utils.CloseableThreadLocal;
Expand Down Expand Up @@ -68,7 +70,9 @@ public TopKMapComputer(
this.sourceNodes = sourceNodes;
this.targetNodes = targetNodes;
this.nodeEmbeddingProperty = nodeEmbeddingProperty;
this.relationshipTypeEmbedding = DoubleArrayList.from(relationshipTypeEmbedding.stream().mapToDouble(Double::doubleValue).toArray());
this.relationshipTypeEmbedding = DoubleArrayList.from(relationshipTypeEmbedding.stream()
.mapToDouble(Double::doubleValue)
.toArray());
this.concurrency = concurrency;
this.topK = topK;
this.scoreFunction = scoreFunction;
Expand All @@ -82,9 +86,15 @@ public KGEPredictResult compute() {

NodePropertyValues embeddings = graph.nodeProperties(nodeEmbeddingProperty);

try (var threadLocalScorer = AutoCloseableThreadLocal.withInitial(() -> LinkScorerFactory.create(scoreFunction, embeddings, relationshipTypeEmbedding))) {
try (
var threadLocalScorer = AutoCloseableThreadLocal.withInitial(() -> LinkScorerFactory.create(
scoreFunction,
embeddings,
relationshipTypeEmbedding
))
) {
//TODO maybe exploit symmetry of similarity function if available when there're many source target overlap
try (var concurrentGraph = CloseableThreadLocal.withInitial(graph::concurrentCopy)){
try (var concurrentGraph = CloseableThreadLocal.withInitial(graph::concurrentCopy)) {
ParallelUtil.parallelStreamConsume(
new SetBitsIterable(sourceNodes).stream(),
concurrency,
Expand All @@ -104,12 +114,13 @@ public KGEPredictResult compute() {
if (!Double.isNaN(similarity)) {
topKMap.put(node1, node2, similarity);
}
progressTracker.logProgress();

});
});
}
);
}
progressTracker.logProgress();
}

progressTracker.endSubTask();
Expand All @@ -127,6 +138,6 @@ private long estimateWorkload() {
}

private LongLongPredicate isCandidateLink(Graph graph) {
return (s, t) -> s != t && !graph.exists(s,t);
return (s, t) -> s != t && !graph.exists(s, t);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.kge.scorers;

import com.carrotsearch.hppc.DoubleArrayList;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;

public class DoubleDistMultLinkScorer implements LinkScorer {

private final NodePropertyValues embeddings;

private final double[] relationshipTypeEmbedding;

private long currentSourceNode;

private double[] currentCandidateTarget;


DoubleDistMultLinkScorer(NodePropertyValues embeddings, DoubleArrayList relationshipTypeEmbedding) {
this.embeddings = embeddings;
this.relationshipTypeEmbedding = relationshipTypeEmbedding.toArray();
}

@Override
public void init(long sourceNode) {
this.currentSourceNode = sourceNode;
this.currentCandidateTarget = embeddings.doubleArrayValue(currentSourceNode);
for(int i = 0; i < relationshipTypeEmbedding.length; i++){
this.currentCandidateTarget[i] *= relationshipTypeEmbedding[i];
}
}

@Override
public double computeScore(long targetNode) {
double res = 0.0;
var targetVector = embeddings.doubleArrayValue(targetNode);
for (int i = 0; i < currentCandidateTarget.length; i++) {
res += currentCandidateTarget[i] * targetVector[i];
}
return res;
}

@Override
public void close() throws Exception {}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.kge.scorers;

import com.carrotsearch.hppc.DoubleArrayList;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;

public class DoubleEuclideanDistanceLinkScorer implements LinkScorer {

NodePropertyValues embeddings;

double[] relationshipTypeEmbedding;

long currentSourceNode;

double[] currentCandidateTarget;

DoubleEuclideanDistanceLinkScorer(NodePropertyValues embeddings, DoubleArrayList relationshipTypeEmbedding) {
this.embeddings = embeddings;
this.relationshipTypeEmbedding = relationshipTypeEmbedding.toArray();
}

@Override
public void init(long sourceNode) {
this.currentSourceNode = sourceNode;
this.currentCandidateTarget = embeddings.doubleArrayValue(currentSourceNode);
for(int i = 0; i < relationshipTypeEmbedding.length; i++){
this.currentCandidateTarget[i] += relationshipTypeEmbedding[i];
}
}

@Override
public double computeScore(long targetNode) {
double res = 0.0;
var targetVector = embeddings.doubleArrayValue(targetNode);
for (int i = 0; i < currentCandidateTarget.length; i++) {
double elem = currentCandidateTarget[i] - targetVector[i];
res += elem * elem;
}
return Math.sqrt(res);
}

@Override
public void close() throws Exception { }

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,44 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.kge;
package org.neo4j.gds.ml.kge.scorers;

import com.carrotsearch.hppc.DoubleArrayList;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.ml.core.tensor.Vector;

public class DistMultLinkScorer implements LinkScorer {
public class FloatDistMultLinkScorer implements LinkScorer {

NodePropertyValues embeddings;

Vector relationshipTypeEmbedding;
double[] relationshipTypeEmbedding;

long currentSourceNode;

Vector currentCandidateTarget;
float[] currentCandidateTarget;


DistMultLinkScorer(NodePropertyValues embeddings, DoubleArrayList relationshipTypeEmbedding) {
FloatDistMultLinkScorer(NodePropertyValues embeddings, DoubleArrayList relationshipTypeEmbedding) {
this.embeddings = embeddings;
this.relationshipTypeEmbedding = new Vector(relationshipTypeEmbedding.toArray());
this.relationshipTypeEmbedding = relationshipTypeEmbedding.toArray();
}

@Override
public void init(long sourceNode) {
this.currentSourceNode = sourceNode;
this.currentCandidateTarget = new Vector(embeddings.doubleArrayValue(currentSourceNode))
.elementwiseProduct(relationshipTypeEmbedding);
this.currentCandidateTarget = embeddings.floatArrayValue(currentSourceNode);
for(int i = 0; i < relationshipTypeEmbedding.length; i++){
this.currentCandidateTarget[i] *= relationshipTypeEmbedding[i];
}
}

@Override
public double computeScore(long targetNode) {
return currentCandidateTarget.elementwiseProduct(new Vector(embeddings.doubleArrayValue(targetNode)))
.aggregateSum();
double res = 0.0;
var targetVector = embeddings.floatArrayValue(targetNode);
for (int i = 0; i < currentCandidateTarget.length; i++) {
res += currentCandidateTarget[i] * targetVector[i];
}
return res;
}

@Override
Expand Down
Loading

0 comments on commit c0330a9

Please sign in to comment.