-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Add KHop UDGA implementation
- Loading branch information
Showing
9 changed files
with
999 additions
and
0 deletions.
There are no files selected for viewing
101 changes: 101 additions & 0 deletions
101
...w/geaflow-dsl/geaflow-dsl-plan/src/main/java/com/antgroup/geaflow/dsl/udf/graph/KHop.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/* | ||
* Copyright 2023 AntGroup CO., Ltd. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
*/ | ||
|
||
package com.antgroup.geaflow.dsl.udf.graph; | ||
|
||
import com.antgroup.geaflow.common.type.primitive.IntegerType; | ||
import com.antgroup.geaflow.common.type.primitive.StringType; | ||
import com.antgroup.geaflow.dsl.common.algo.AlgorithmRuntimeContext; | ||
import com.antgroup.geaflow.dsl.common.algo.AlgorithmUserFunction; | ||
import com.antgroup.geaflow.dsl.common.data.RowEdge; | ||
import com.antgroup.geaflow.dsl.common.data.RowVertex; | ||
import com.antgroup.geaflow.dsl.common.data.impl.ObjectRow; | ||
import com.antgroup.geaflow.dsl.common.function.Description; | ||
import com.antgroup.geaflow.dsl.common.types.StructType; | ||
import com.antgroup.geaflow.dsl.common.types.TableField; | ||
import com.antgroup.geaflow.dsl.common.util.TypeCastUtil; | ||
import com.antgroup.geaflow.model.graph.edge.EdgeDirection; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Iterator; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
@Description(name = "khop", description = "built-in udga for KHop") | ||
public class KHop implements AlgorithmUserFunction<Object, Integer> { | ||
|
||
private static final String OUTPUT_ID = "id"; | ||
private static final String OUTPUT_K = "k"; | ||
private AlgorithmRuntimeContext<Object, Integer> context; | ||
private Object srcId; | ||
private int k = 1; | ||
|
||
@Override | ||
public void init(AlgorithmRuntimeContext<Object, Integer> context, Object[] parameters) { | ||
this.context = context; | ||
if (parameters.length > 2) { | ||
throw new IllegalArgumentException( | ||
"Only support zero or more arguments, false arguments " | ||
+ "usage: func([alpha, [convergence, [max_iteration]]])"); | ||
} | ||
if (parameters.length > 0) { | ||
srcId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType()); | ||
} | ||
if (parameters.length > 1) { | ||
k = Integer.parseInt(String.valueOf(parameters[1])); | ||
} | ||
} | ||
|
||
@Override | ||
public void process(RowVertex vertex, Iterator<Integer> messages) { | ||
List<RowEdge> outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); | ||
if (context.getCurrentIterationId() == 1L) { | ||
if (Objects.equals(srcId, vertex.getId())) { | ||
sendMessageToNeighbors(outEdges, 1); | ||
context.updateVertexValue(ObjectRow.create(0)); | ||
} else { | ||
context.updateVertexValue(ObjectRow.create(Integer.MAX_VALUE)); | ||
} | ||
} else if (context.getCurrentIterationId() <= k + 1) { | ||
int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); | ||
if (messages.hasNext() && currentK == Integer.MAX_VALUE) { | ||
Integer currK = messages.next(); | ||
context.updateVertexValue(ObjectRow.create(currK)); | ||
sendMessageToNeighbors(outEdges, currK + 1); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public StructType getOutputType() { | ||
return new StructType( | ||
new TableField(OUTPUT_ID, StringType.INSTANCE, false), | ||
new TableField(OUTPUT_K, IntegerType.INSTANCE, false) | ||
); | ||
} | ||
|
||
@Override | ||
public void finish(RowVertex vertex) { | ||
int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); | ||
if (currentK != Integer.MAX_VALUE) { | ||
context.take(ObjectRow.create(vertex.getId(), currentK)); | ||
} | ||
} | ||
|
||
private void sendMessageToNeighbors(List<RowEdge> outEdges, Integer message) { | ||
for (RowEdge rowEdge : outEdges) { | ||
context.sendMessage(rowEdge.getTargetId(), message); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_007.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
1,0 | ||
3,1 | ||
5,1 | ||
6,1 | ||
4,2 |
9 changes: 9 additions & 0 deletions
9
geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/input/test_edge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
1,3 | ||
1,5 | ||
1,6 | ||
2,3 | ||
3,4 | ||
4,1 | ||
4,6 | ||
5,4 | ||
5,6 |
6 changes: 6 additions & 0 deletions
6
geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/input/test_vertex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
1,1, , | ||
2,1, , | ||
3,1, , | ||
4,1, , | ||
5,1, , | ||
6,1, , |
61 changes: 61 additions & 0 deletions
61
geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_007.sql
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
-- Test script for the khop UDGA: builds graph g4 from file fixtures, then runs khop. | ||
set geaflow.dsl.window.size = -1; | ||
set geaflow.dsl.ignore.exception = true; | ||
|
||
-- Graph with string vertex ids and edges that carry no properties. | ||
CREATE GRAPH IF NOT EXISTS g4 ( | ||
Vertex v4 ( | ||
vid varchar ID, | ||
vvalue int | ||
), | ||
Edge e4 ( | ||
srcId varchar SOURCE ID, | ||
targetId varchar DESTINATION ID | ||
) | ||
) WITH ( | ||
storeType='rocksdb', | ||
shardCount = 1 | ||
); | ||
|
||
-- Vertex fixture; only v_id and v_value are loaded into the graph below. | ||
CREATE TABLE IF NOT EXISTS v_source ( | ||
v_id varchar, | ||
v_value int, | ||
ts varchar, | ||
type varchar | ||
) WITH ( | ||
type='file', | ||
geaflow.dsl.file.path = 'resource:///input/test_vertex' | ||
); | ||
|
||
-- Edge fixture: one (source id, target id) pair per row. | ||
CREATE TABLE IF NOT EXISTS e_source ( | ||
src_id varchar, | ||
dst_id varchar | ||
) WITH ( | ||
type='file', | ||
geaflow.dsl.file.path = 'resource:///input/test_edge' | ||
); | ||
|
||
-- Result sink: one (vertex id, hop count) row per reached vertex. | ||
CREATE TABLE IF NOT EXISTS tbl_result ( | ||
v_id varchar, | ||
k_value int | ||
) WITH ( | ||
type='file', | ||
geaflow.dsl.file.path = '${target}' | ||
); | ||
|
||
USE GRAPH g4; | ||
|
||
INSERT INTO g4.v4(vid, vvalue) | ||
SELECT | ||
v_id, v_value | ||
FROM v_source; | ||
|
||
INSERT INTO g4.e4(srcId, targetId) | ||
SELECT | ||
src_id, dst_id | ||
FROM e_source; | ||
|
||
CREATE Function khop AS 'com.antgroup.geaflow.dsl.udf.graph.KHop'; | ||
|
||
-- Arguments are (srcId, k): start from vertex "1" and expand at most 2 hops. | ||
INSERT INTO tbl_result(v_id, k_value) | ||
CALL khop("1",2) YIELD (vid, kValue) | ||
RETURN vid, kValue | ||
; |
165 changes: 165 additions & 0 deletions
165
...examples/src/main/java/com/antgroup/geaflow/example/graph/statical/compute/khop/KHop.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
/* | ||
* Copyright 2023 AntGroup CO., Ltd. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
*/ | ||
|
||
package com.antgroup.geaflow.example.graph.statical.compute.khop; | ||
|
||
import com.antgroup.geaflow.api.function.io.SinkFunction; | ||
import com.antgroup.geaflow.api.graph.compute.VertexCentricCompute; | ||
import com.antgroup.geaflow.api.graph.function.vc.VertexCentricCombineFunction; | ||
import com.antgroup.geaflow.api.graph.function.vc.VertexCentricComputeFunction; | ||
import com.antgroup.geaflow.api.pdata.stream.window.PWindowSource; | ||
import com.antgroup.geaflow.api.pdata.stream.window.PWindowStream; | ||
import com.antgroup.geaflow.api.window.impl.AllWindow; | ||
import com.antgroup.geaflow.common.config.Configuration; | ||
import com.antgroup.geaflow.env.Environment; | ||
import com.antgroup.geaflow.example.config.ExampleConfigKeys; | ||
import com.antgroup.geaflow.example.function.AbstractVcFunc; | ||
import com.antgroup.geaflow.example.function.FileSink; | ||
import com.antgroup.geaflow.example.function.FileSource; | ||
import com.antgroup.geaflow.example.util.EnvironmentUtil; | ||
import com.antgroup.geaflow.example.util.ExampleSinkFunctionFactory; | ||
import com.antgroup.geaflow.example.util.ResultValidator; | ||
import com.antgroup.geaflow.model.graph.edge.IEdge; | ||
import com.antgroup.geaflow.model.graph.edge.impl.ValueEdge; | ||
import com.antgroup.geaflow.model.graph.vertex.IVertex; | ||
import com.antgroup.geaflow.model.graph.vertex.impl.ValueVertex; | ||
import com.antgroup.geaflow.pipeline.IPipelineResult; | ||
import com.antgroup.geaflow.pipeline.Pipeline; | ||
import com.antgroup.geaflow.pipeline.PipelineFactory; | ||
import com.antgroup.geaflow.pipeline.task.PipelineTask; | ||
import com.antgroup.geaflow.view.GraphViewBuilder; | ||
import com.antgroup.geaflow.view.IViewDesc.BackendType; | ||
import com.antgroup.geaflow.view.graph.GraphViewDesc; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.Iterator; | ||
import java.util.Objects; | ||
|
||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
public class KHop { | ||
|
||
private static final Logger LOGGER = LoggerFactory.getLogger(KHop.class); | ||
|
||
public static final String RESULT_FILE_PATH = "./target/tmp/data/result/KHop"; | ||
public static final String REF_FILE_PATH = "data/reference/KHop"; | ||
|
||
private static int k = 2; | ||
private static Object srcId = 990; | ||
|
||
public KHop(Object inputId, int inputK) { | ||
srcId = inputId; | ||
k = inputK; | ||
} | ||
|
||
public static void main(String[] args) { | ||
Environment environment = EnvironmentUtil.loadEnvironment(args); | ||
submit(environment); | ||
} | ||
|
||
public static IPipelineResult submit(Environment environment) { | ||
Pipeline pipeline = PipelineFactory.buildPipeline(environment); | ||
Configuration envConfig = environment.getEnvironmentContext().getConfig(); | ||
envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); | ||
ResultValidator.cleanResult(RESULT_FILE_PATH); | ||
|
||
pipeline.submit((PipelineTask) pipelineTaskCxt -> { | ||
Configuration conf = pipelineTaskCxt.getConfig(); | ||
int sinkParallelism = conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM); | ||
PWindowSource<IVertex<Object, Integer>> vertices = | ||
pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex", | ||
line -> { | ||
String[] fields = line.split(","); | ||
IVertex<Object, Integer> vertex = new ValueVertex<>( | ||
fields[0], Integer.valueOf(fields[1])); | ||
return Collections.singletonList(vertex); | ||
}), AllWindow.getInstance()) | ||
.withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); | ||
|
||
PWindowSource<IEdge<Object, Object>> edges = pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge", | ||
line -> { | ||
String[] fields = line.split(","); | ||
IEdge<Object, Object> edge = new ValueEdge<>(fields[0], fields[1], 1); | ||
return Collections.singletonList(edge); | ||
}), AllWindow.getInstance()) | ||
.withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM)); | ||
|
||
int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM); | ||
GraphViewDesc graphViewDesc = GraphViewBuilder | ||
.createGraphView(GraphViewBuilder.DEFAULT_GRAPH) | ||
.withShardNum(2) | ||
.withBackend(BackendType.Memory) | ||
.build(); | ||
|
||
PWindowStream<IVertex<Object, Integer>> result = | ||
pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc) | ||
.compute(new KHAlgorithms(k + 1)) | ||
.compute(iterationParallelism) | ||
.getVertices(); | ||
|
||
SinkFunction<String> sink = ExampleSinkFunctionFactory.getSinkFunction(conf); | ||
result.filter(v -> v.getValue() < k + 1).map(v -> String.format("%s,%s", v.getId(), v.getValue())) | ||
.sink(sink).withParallelism(sinkParallelism); | ||
}); | ||
|
||
return pipeline.execute(); | ||
} | ||
|
||
public static void validateResult() throws IOException { | ||
ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH); | ||
} | ||
|
||
public static class KHAlgorithms extends VertexCentricCompute<Object, Integer, Object, Integer> { | ||
|
||
public KHAlgorithms(long iterations) { | ||
super(iterations); | ||
} | ||
|
||
@Override | ||
public VertexCentricComputeFunction<Object, Integer, Object, Integer> getComputeFunction() { | ||
return new KHVertexCentricComputeFunction(); | ||
} | ||
|
||
@Override | ||
public VertexCentricCombineFunction<Integer> getCombineFunction() { | ||
return null; | ||
} | ||
|
||
} | ||
|
||
public static class KHVertexCentricComputeFunction extends AbstractVcFunc<Object, Integer, Object, Integer> { | ||
|
||
@Override | ||
public void compute(Object vertexId, | ||
Iterator<Integer> messageIterator) { | ||
IVertex<Object, Integer> vertex = this.context.vertex().get(); | ||
if (this.context.getIterationId() == 1L) { | ||
if (Objects.equals(vertex.getId(), srcId)) { | ||
this.context.sendMessageToNeighbors(1); | ||
this.context.setNewVertexValue(0); | ||
} else { | ||
this.context.setNewVertexValue(Integer.MAX_VALUE); | ||
} | ||
} else { | ||
if (vertex.getValue() == Integer.MAX_VALUE && messageIterator.hasNext()) { | ||
int value = messageIterator.next(); | ||
this.context.sendMessageToNeighbors(value + 1); | ||
this.context.setNewVertexValue(value); | ||
} | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.