Skip to content

Commit

Permalink
#195 (#196) Add KHop UDGA implementation
Browse files Browse the repository at this point in the history
Add KHop UDGA implementation
  • Loading branch information
Leomrlin authored Oct 12, 2023
1 parent 4288afa commit 75dd00d
Show file tree
Hide file tree
Showing 9 changed files with 999 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright 2023 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

package com.antgroup.geaflow.dsl.udf.graph;

import com.antgroup.geaflow.common.type.primitive.IntegerType;
import com.antgroup.geaflow.common.type.primitive.StringType;
import com.antgroup.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
import com.antgroup.geaflow.dsl.common.algo.AlgorithmUserFunction;
import com.antgroup.geaflow.dsl.common.data.RowEdge;
import com.antgroup.geaflow.dsl.common.data.RowVertex;
import com.antgroup.geaflow.dsl.common.data.impl.ObjectRow;
import com.antgroup.geaflow.dsl.common.function.Description;
import com.antgroup.geaflow.dsl.common.types.StructType;
import com.antgroup.geaflow.dsl.common.types.TableField;
import com.antgroup.geaflow.dsl.common.util.TypeCastUtil;
import com.antgroup.geaflow.model.graph.edge.EdgeDirection;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/**
 * Built-in UDGA that labels every vertex reached from a source vertex along
 * outgoing edges with its hop distance, exploring at most {@code k} hops.
 *
 * <p>Usage: {@code khop([srcId, [k]])}. {@code srcId} is cast to the graph's id
 * type; {@code k} defaults to 1. Each output row is {@code (id, k)}.
 */
@Description(name = "khop", description = "built-in udga for KHop")
public class KHop implements AlgorithmUserFunction<Object, Integer> {

    private static final String OUTPUT_ID = "id";
    private static final String OUTPUT_K = "k";
    private AlgorithmRuntimeContext<Object, Integer> context;
    private Object srcId;
    private int k = 1;

    /**
     * Stores the runtime context and reads the optional arguments.
     *
     * @param context    runtime context used for edges, messages and vertex values
     * @param parameters {@code [srcId, [k]]}; at most two entries are accepted
     * @throws IllegalArgumentException if more than two arguments are supplied
     */
    @Override
    public void init(AlgorithmRuntimeContext<Object, Integer> context, Object[] parameters) {
        this.context = context;
        if (parameters.length > 2) {
            throw new IllegalArgumentException(
                "Only support at most two arguments, usage: khop([srcId, [k]])");
        }
        if (parameters.length > 0) {
            srcId = TypeCastUtil.cast(parameters[0], context.getGraphSchema().getIdType());
        }
        if (parameters.length > 1) {
            k = Integer.parseInt(String.valueOf(parameters[1]));
        }
    }

    /**
     * Runs one iteration for {@code vertex}.
     *
     * <p>Iteration 1 seeds the source with distance 0 and marks every other vertex
     * with {@link Integer#MAX_VALUE} (meaning "not reached"). In iterations up to
     * {@code k + 1}, a still-unreached vertex that receives a message adopts the
     * carried distance and forwards {@code distance + 1} along its out-edges.
     */
    @Override
    public void process(RowVertex vertex, Iterator<Integer> messages) {
        long iterationId = context.getCurrentIterationId();
        if (iterationId == 1L) {
            if (Objects.equals(srcId, vertex.getId())) {
                sendMessageToNeighbors(loadOutEdges(), 1);
                context.updateVertexValue(ObjectRow.create(0));
            } else {
                context.updateVertexValue(ObjectRow.create(Integer.MAX_VALUE));
            }
        } else if (iterationId <= k + 1) {
            int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE);
            if (messages.hasNext() && currentK == Integer.MAX_VALUE) {
                // Only the first message is read; the remaining ones are ignored.
                Integer currK = messages.next();
                List<RowEdge> outEdges = loadOutEdges();
                context.updateVertexValue(ObjectRow.create(currK));
                sendMessageToNeighbors(outEdges, currK + 1);
            }
        }
    }

    /**
     * Describes the output row: an {@code id} column and a {@code k} column.
     */
    @Override
    public StructType getOutputType() {
        // NOTE(review): the id column is declared as a string regardless of the
        // graph's id type -- confirm this is intended for non-string ids.
        return new StructType(
            new TableField(OUTPUT_ID, StringType.INSTANCE, false),
            new TableField(OUTPUT_K, IntegerType.INSTANCE, false)
        );
    }

    /**
     * Emits {@code (vertexId, distance)} for every vertex whose stored distance is
     * not the "not reached" marker.
     */
    @Override
    public void finish(RowVertex vertex) {
        int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE);
        if (currentK != Integer.MAX_VALUE) {
            context.take(ObjectRow.create(vertex.getId(), currentK));
        }
    }

    /** Loads the current vertex's outgoing edges; called only when messages are sent. */
    private List<RowEdge> loadOutEdges() {
        return new ArrayList<>(context.loadEdges(EdgeDirection.OUT));
    }

    /** Sends {@code message} to the target of every edge in {@code outEdges}. */
    private void sendMessageToNeighbors(List<RowEdge> outEdges, Integer message) {
        for (RowEdge rowEdge : outEdges) {
            context.sendMessage(rowEdge.getTargetId(), message);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,12 @@ public void testAlgorithm_006() throws Exception {
.execute()
.checkSinkResult();
}
/**
 * Runs the query script at {@code /query/gql_algorithm_007.sql} through
 * {@code QueryTester} and checks the sink result.
 */
@Test
public void testAlgorithm_007() throws Exception {
QueryTester
.build()
.withQueryPath("/query/gql_algorithm_007.sql")
.execute()
.checkSinkResult();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1,0
3,1
5,1
6,1
4,2
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
1,3
1,5
1,6
2,3
3,4
4,1
4,6
5,4
5,6
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
1,1, ,
2,1, ,
3,1, ,
4,1, ,
5,1, ,
6,1, ,
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
-- K-hop query script: builds graph g4 from two file sources, then calls the
-- khop function and writes its rows to a file sink.
set geaflow.dsl.window.size = -1;
set geaflow.dsl.ignore.exception = true;

-- Graph with varchar vertex ids and edges that carry only source/target ids.
CREATE GRAPH IF NOT EXISTS g4 (
Vertex v4 (
vid varchar ID,
vvalue int
),
Edge e4 (
srcId varchar SOURCE ID,
targetId varchar DESTINATION ID
)
) WITH (
storeType='rocksdb',
shardCount = 1
);

-- Vertex input file; only v_id and v_value are inserted into the graph below.
CREATE TABLE IF NOT EXISTS v_source (
v_id varchar,
v_value int,
ts varchar,
type varchar
) WITH (
type='file',
geaflow.dsl.file.path = 'resource:///input/test_vertex'
);

-- Edge input file: one (src_id, dst_id) pair per row.
CREATE TABLE IF NOT EXISTS e_source (
src_id varchar,
dst_id varchar
) WITH (
type='file',
geaflow.dsl.file.path = 'resource:///input/test_edge'
);

-- Result sink; the output path comes from the ${target} placeholder.
CREATE TABLE IF NOT EXISTS tbl_result (
v_id varchar,
k_value int
) WITH (
type='file',
geaflow.dsl.file.path = '${target}'
);

USE GRAPH g4;

-- Load vertices into g4.
INSERT INTO g4.v4(vid, vvalue)
SELECT
v_id, v_value
FROM v_source;

-- Load edges into g4.
INSERT INTO g4.e4(srcId, targetId)
SELECT
src_id, dst_id
FROM e_source;

-- Register the KHop implementation class under the name khop.
CREATE Function khop AS 'com.antgroup.geaflow.dsl.udf.graph.KHop';

-- Call khop with arguments "1" and 2 and write the yielded (vid, kValue) rows
-- to tbl_result.
INSERT INTO tbl_result(v_id, k_value)
CALL khop("1",2) YIELD (vid, kValue)
RETURN vid, kValue
;
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2023 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

package com.antgroup.geaflow.example.graph.statical.compute.khop;

import com.antgroup.geaflow.api.function.io.SinkFunction;
import com.antgroup.geaflow.api.graph.compute.VertexCentricCompute;
import com.antgroup.geaflow.api.graph.function.vc.VertexCentricCombineFunction;
import com.antgroup.geaflow.api.graph.function.vc.VertexCentricComputeFunction;
import com.antgroup.geaflow.api.pdata.stream.window.PWindowSource;
import com.antgroup.geaflow.api.pdata.stream.window.PWindowStream;
import com.antgroup.geaflow.api.window.impl.AllWindow;
import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.env.Environment;
import com.antgroup.geaflow.example.config.ExampleConfigKeys;
import com.antgroup.geaflow.example.function.AbstractVcFunc;
import com.antgroup.geaflow.example.function.FileSink;
import com.antgroup.geaflow.example.function.FileSource;
import com.antgroup.geaflow.example.util.EnvironmentUtil;
import com.antgroup.geaflow.example.util.ExampleSinkFunctionFactory;
import com.antgroup.geaflow.example.util.ResultValidator;
import com.antgroup.geaflow.model.graph.edge.IEdge;
import com.antgroup.geaflow.model.graph.edge.impl.ValueEdge;
import com.antgroup.geaflow.model.graph.vertex.IVertex;
import com.antgroup.geaflow.model.graph.vertex.impl.ValueVertex;
import com.antgroup.geaflow.pipeline.IPipelineResult;
import com.antgroup.geaflow.pipeline.Pipeline;
import com.antgroup.geaflow.pipeline.PipelineFactory;
import com.antgroup.geaflow.pipeline.task.PipelineTask;
import com.antgroup.geaflow.view.GraphViewBuilder;
import com.antgroup.geaflow.view.IViewDesc.BackendType;
import com.antgroup.geaflow.view.graph.GraphViewDesc;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example pipeline that computes, for a source vertex, the hop distance of every
 * vertex reachable within {@code k} hops, and writes {@code id,distance} lines
 * to the configured sink.
 *
 * <p>Vertices and edges are read from {@code data/input/email_vertex} and
 * {@code data/input/email_edge}. Ids are taken verbatim from the comma-separated
 * input lines, so they are {@link String}s; the source id must therefore be a
 * {@code String} for the equality check in the compute function to match.
 */
public class KHop {

    private static final Logger LOGGER = LoggerFactory.getLogger(KHop.class);

    public static final String RESULT_FILE_PATH = "./target/tmp/data/result/KHop";
    public static final String REF_FILE_PATH = "data/reference/KHop";

    private static int k = 2;
    // Must be a String: vertex ids are built from fields[0] of the input line,
    // so an Integer default could never equal any vertex id.
    private static Object srcId = "990";

    /**
     * Overrides the source id and hop count. Both are stored in static fields, so
     * the values apply to every later {@link #submit(Environment)} call.
     *
     * @param inputId id of the source vertex (a String for the bundled input files)
     * @param inputK  maximum number of hops to explore
     */
    public KHop(Object inputId, int inputK) {
        srcId = inputId;
        k = inputK;
    }

    public static void main(String[] args) {
        Environment environment = EnvironmentUtil.loadEnvironment(args);
        submit(environment);
    }

    /**
     * Builds and executes the K-hop pipeline.
     *
     * @param environment environment used to build the pipeline and read its config
     * @return the result handle returned by {@code pipeline.execute()}
     */
    public static IPipelineResult submit(Environment environment) {
        Pipeline pipeline = PipelineFactory.buildPipeline(environment);
        Configuration envConfig = environment.getEnvironmentContext().getConfig();
        envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH);
        ResultValidator.cleanResult(RESULT_FILE_PATH);

        pipeline.submit((PipelineTask) pipelineTaskCxt -> {
            Configuration conf = pipelineTaskCxt.getConfig();
            int sinkParallelism = conf.getInteger(ExampleConfigKeys.SINK_PARALLELISM);

            // Each vertex line is "id,value".
            PWindowSource<IVertex<Object, Integer>> vertices =
                pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_vertex",
                    line -> {
                        String[] fields = line.split(",");
                        IVertex<Object, Integer> vertex = new ValueVertex<>(
                            fields[0], Integer.valueOf(fields[1]));
                        return Collections.singletonList(vertex);
                    }), AllWindow.getInstance())
                    .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM));

            // Each edge line is "src,dst"; every edge gets the constant value 1.
            PWindowSource<IEdge<Object, Object>> edges =
                pipelineTaskCxt.buildSource(new FileSource<>("data/input/email_edge",
                    line -> {
                        String[] fields = line.split(",");
                        IEdge<Object, Object> edge = new ValueEdge<>(fields[0], fields[1], 1);
                        return Collections.singletonList(edge);
                    }), AllWindow.getInstance())
                    .withParallelism(conf.getInteger(ExampleConfigKeys.SOURCE_PARALLELISM));

            int iterationParallelism = conf.getInteger(ExampleConfigKeys.ITERATOR_PARALLELISM);
            GraphViewDesc graphViewDesc = GraphViewBuilder
                .createGraphView(GraphViewBuilder.DEFAULT_GRAPH)
                .withShardNum(2)
                .withBackend(BackendType.Memory)
                .build();

            // k + 1 iterations: one to seed the source, then one per hop.
            PWindowStream<IVertex<Object, Integer>> result =
                pipelineTaskCxt.buildWindowStreamGraph(vertices, edges, graphViewDesc)
                    .compute(new KHAlgorithms(k + 1))
                    .compute(iterationParallelism)
                    .getVertices();

            SinkFunction<String> sink = ExampleSinkFunctionFactory.getSinkFunction(conf);
            // Unreached vertices hold Integer.MAX_VALUE and are dropped by the filter.
            result.filter(v -> v.getValue() < k + 1)
                .map(v -> String.format("%s,%s", v.getId(), v.getValue()))
                .sink(sink).withParallelism(sinkParallelism);
        });

        return pipeline.execute();
    }

    /**
     * Compares the produced result files against the reference files.
     *
     * @throws IOException declared for {@code ResultValidator.validateResult}
     */
    public static void validateResult() throws IOException {
        ResultValidator.validateResult(REF_FILE_PATH, RESULT_FILE_PATH);
    }

    /** Vertex-centric K-hop computation with a fixed iteration count and no combiner. */
    public static class KHAlgorithms extends VertexCentricCompute<Object, Integer, Object, Integer> {

        public KHAlgorithms(long iterations) {
            super(iterations);
        }

        @Override
        public VertexCentricComputeFunction<Object, Integer, Object, Integer> getComputeFunction() {
            return new KHVertexCentricComputeFunction();
        }

        @Override
        public VertexCentricCombineFunction<Integer> getCombineFunction() {
            return null;
        }

    }

    /**
     * Per-vertex step: iteration 1 seeds the source with 0 and all other vertices
     * with {@link Integer#MAX_VALUE}; later iterations let a still-unreached vertex
     * adopt the first received distance and forward {@code distance + 1}.
     */
    public static class KHVertexCentricComputeFunction extends AbstractVcFunc<Object, Integer, Object, Integer> {

        @Override
        public void compute(Object vertexId,
                            Iterator<Integer> messageIterator) {
            IVertex<Object, Integer> vertex = this.context.vertex().get();
            if (this.context.getIterationId() == 1L) {
                if (Objects.equals(vertex.getId(), srcId)) {
                    this.context.sendMessageToNeighbors(1);
                    this.context.setNewVertexValue(0);
                } else {
                    this.context.setNewVertexValue(Integer.MAX_VALUE);
                }
            } else {
                if (vertex.getValue() == Integer.MAX_VALUE && messageIterator.hasNext()) {
                    int value = messageIterator.next();
                    this.context.sendMessageToNeighbors(value + 1);
                    this.context.setNewVertexValue(value);
                }
            }
        }
    }
}
Loading

0 comments on commit 75dd00d

Please sign in to comment.