From 0868323f5a6860841f03b80015a021d4a3f331bc Mon Sep 17 00:00:00 2001 From: JiJi <824050413@qq.com> Date: Tue, 24 Dec 2024 09:58:52 +0800 Subject: [PATCH] [ISSUE-370] add common_neighbors algorithm (#430) * add common_neighbors algorithm * expand imports --- .../function/BuildInSqlFunctionTable.java | 2 + .../dsl/udf/graph/CommonNeighbors.java | 84 +++++++++++++++++++ .../dsl/runtime/query/GQLAlgorithmTest.java | 11 +++ .../expect/gql_algorithm_common_neighbors.txt | 1 + .../query/gql_algorithm_common_neighbors.sql | 13 +++ 5 files changed, 111 insertions(+) create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/udf/graph/CommonNeighbors.java create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_common_neighbors.txt create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_common_neighbors.sql diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java index e61e2b611..34bab4ace 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java @@ -20,6 +20,7 @@ import com.antgroup.geaflow.dsl.schema.GeaFlowFunction; import com.antgroup.geaflow.dsl.udf.graph.AllSourceShortestPath; import com.antgroup.geaflow.dsl.udf.graph.ClosenessCentrality; +import com.antgroup.geaflow.dsl.udf.graph.CommonNeighbors; import com.antgroup.geaflow.dsl.udf.graph.IncWeakConnectedComponents; import com.antgroup.geaflow.dsl.udf.graph.KCore; import com.antgroup.geaflow.dsl.udf.graph.KHop; @@ -185,6 +186,7 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable { 
.add(GeaFlowFunction.of(WeakConnectedComponents.class)) .add(GeaFlowFunction.of(TriangleCount.class)) .add(GeaFlowFunction.of(IncWeakConnectedComponents.class)) + .add(GeaFlowFunction.of(CommonNeighbors.class)) .build(); public BuildInSqlFunctionTable(GQLJavaTypeFactory typeFactory) {
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/udf/graph/CommonNeighbors.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/udf/graph/CommonNeighbors.java
new file mode 100644
index 000000000..201a6ee74
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/udf/graph/CommonNeighbors.java
@@ -0,0 +1,114 @@
+package com.antgroup.geaflow.dsl.udf.graph;
+
+import com.antgroup.geaflow.common.tuple.Tuple;
+import com.antgroup.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
+import com.antgroup.geaflow.dsl.common.algo.AlgorithmUserFunction;
+import com.antgroup.geaflow.dsl.common.data.Row;
+import com.antgroup.geaflow.dsl.common.data.RowEdge;
+import com.antgroup.geaflow.dsl.common.data.RowVertex;
+import com.antgroup.geaflow.dsl.common.data.impl.ObjectRow;
+import com.antgroup.geaflow.dsl.common.function.Description;
+import com.antgroup.geaflow.dsl.common.types.GraphSchema;
+import com.antgroup.geaflow.dsl.common.types.StructType;
+import com.antgroup.geaflow.dsl.common.types.TableField;
+import com.antgroup.geaflow.dsl.common.util.TypeCastUtil;
+import com.antgroup.geaflow.model.graph.edge.EdgeDirection;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Built-in graph algorithm that emits the ids of the vertices adjacent to both of the two
+ * vertices passed as arguments, e.g. {@code CALL common_neighbors(1, 3) YIELD (id)}.
+ *
+ * <p>Iteration 1: each of the two argument vertices sends its own id along every edge loaded
+ * with {@code EdgeDirection.BOTH}. Iteration 2: a vertex that received both ids is emitted once.
+ */
+@Description(name = "common_neighbors", description = "built-in udga for CommonNeighbors")
+public class CommonNeighbors implements AlgorithmUserFunction<Object, Object> {
+
+    private AlgorithmRuntimeContext<Object, Object> context;
+
+    // the two vertex ids given as arguments, cast to the id type of the graph
+    private Tuple<Object, Object> vertices;
+
+    /**
+     * Keeps the runtime context and the two vertex ids whose common neighbors are searched.
+     *
+     * @param context runtime context of the algorithm
+     * @param params call arguments; exactly two vertex ids are expected
+     * @throws IllegalArgumentException if the number of arguments is not two
+     */
+    @Override
+    public void init(AlgorithmRuntimeContext<Object, Object> context, Object[] params) {
+        this.context = context;
+
+        if (params.length != 2) {
+            throw new IllegalArgumentException("Only support two arguments, usage: common_neighbors(id_a, id_b)");
+        }
+        this.vertices = new Tuple<>(
+            TypeCastUtil.cast(params[0], context.getGraphSchema().getIdType()),
+            TypeCastUtil.cast(params[1], context.getGraphSchema().getIdType())
+        );
+    }
+
+    /**
+     * Runs one iteration of the algorithm for a single vertex.
+     *
+     * @param vertex vertex being processed
+     * @param updatedValues not used by this algorithm
+     * @param messages messages delivered to this vertex
+     */
+    @Override
+    public void process(RowVertex vertex, Optional<Row> updatedValues, Iterator<Object> messages) {
+        if (context.getCurrentIterationId() == 1L) {
+            // only the two argument vertices announce themselves to their neighbors
+            if (vertices.f0.equals(vertex.getId()) || vertices.f1.equals(vertex.getId())) {
+                sendMessageToNeighbors(context.loadEdges(EdgeDirection.BOTH), vertex.getId());
+            }
+        } else if (context.getCurrentIterationId() == 2L) {
+            boolean receivedFirst = false;
+            boolean receivedSecond = false;
+            while (messages.hasNext()) {
+                Object message = messages.next();
+                if (vertices.f0.equals(message)) {
+                    receivedFirst = true;
+                }
+                if (vertices.f1.equals(message)) {
+                    receivedSecond = true;
+                }
+            }
+
+            // emit after the loop so that a vertex is output once even if the same id
+            // arrives several times, e.g. over parallel or bidirectional edges
+            if (receivedFirst && receivedSecond) {
+                context.take(ObjectRow.create(vertex.getId()));
+            }
+        }
+    }
+
+    /**
+     * Nothing to do here: results are already emitted in {@code process}.
+     */
+    @Override
+    public void finish(RowVertex graphVertex, Optional<Row> updatedValues) {
+    }
+
+    /**
+     * Describes the output row: a single {@code id} column of the graph id type.
+     */
+    @Override
+    public StructType getOutputType(GraphSchema graphSchema) {
+        return new StructType(
+            new TableField("id", graphSchema.getIdType(), false)
+        );
+    }
+
+    // sends the message to the target of every given edge
+    private void sendMessageToNeighbors(List<RowEdge> edges, Object message) {
+        for (RowEdge rowEdge : edges) {
+            context.sendMessage(rowEdge.getTargetId(), message);
+        }
+    }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/com/antgroup/geaflow/dsl/runtime/query/GQLAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/com/antgroup/geaflow/dsl/runtime/query/GQLAlgorithmTest.java index 7249d00cd..2abdaa289 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/com/antgroup/geaflow/dsl/runtime/query/GQLAlgorithmTest.java +++ 
b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/com/antgroup/geaflow/dsl/runtime/query/GQLAlgorithmTest.java @@ -191,6 +191,17 @@ public void testIncGraphAlgorithm_assp() throws Exception { .checkSinkResult(); } +
+    @Test
+    public void testAlgorithmCommonNeighbors() throws Exception { // end-to-end run of the common_neighbors query
+        QueryTester
+            .build()
+            .withGraphDefine("/query/modern_graph.sql") // defines the graph the query selects with USE GRAPH modern
+            .withQueryPath("/query/gql_algorithm_common_neighbors.sql") // CALL common_neighbors(1, 3) written into result_tb
+            .execute()
+            .checkSinkResult(); // presumably compares the sink with expect/gql_algorithm_common_neighbors.txt - confirm in QueryTester
+    }
+ private void clearGraph() throws IOException { File file = new File(TEST_GRAPH_PATH); if (file.exists()) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_common_neighbors.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_common_neighbors.txt new file mode 100644 index 000000000..bf0d87ab1 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_common_neighbors.txt @@ -0,0 +1 @@ +4 \ No newline at end of file diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_common_neighbors.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_common_neighbors.sql new file mode 100644 index 000000000..9ba8dded8 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_common_neighbors.sql @@ -0,0 +1,13 @@ +CREATE TABLE result_tb ( + vid int +) WITH ( + type='file', + geaflow.dsl.file.path='${target}' +); + +USE GRAPH modern; + +INSERT INTO result_tb +CALL common_neighbors(1, 3) YIELD (id) +RETURN cast (id as int) +; \ No newline at end of file