Skip to content

Commit

Permalink
Rewrite tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IoannisPanagiotas committed Oct 17, 2023
1 parent 2df1366 commit e0b80d2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@
import org.neo4j.gds.catalog.GraphProjectProc;
import org.neo4j.gds.catalog.GraphStreamNodePropertiesProc;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.Neo4jGraph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
Expand Down Expand Up @@ -69,6 +73,9 @@ class LeidenMutateProcTest extends BaseProcTest {
" (a5)-[:R {weight: 1.0}]->(a7)," +
" (a6)-[:R {weight: 1.0}]->(a7)";

@Inject
IdFunction idFunction;

@BeforeEach
void setUp() throws Exception {
registerProcedures(
Expand All @@ -84,16 +91,36 @@ void setUp() throws Exception {
@ParameterizedTest
@ValueSource(strings = {"gds.leiden","gds.beta.leiden"})
void mutate(String procedureName) {

var query = "CALL " + procedureName + ".mutate('leiden', {mutateProperty: 'communityId', concurrency: 1})";
assertLeidenMutateQuery(query);

Graph mutatedGraph = GraphStoreCatalog.get(getUsername(), DatabaseId.of(db.databaseName()), "leiden").graphStore().getUnion();

var communities = mutatedGraph.nodeProperties("communityId");
var communitySet = new HashSet<Long>();
HashMap<Long, Long> communitiesSet = new HashMap<>();

mutatedGraph.forEachNode(nodeId -> {
communitySet.add(communities.longValue(nodeId));
var community = communities.longValue(nodeId);
var neo4jId = mutatedGraph.toOriginalNodeId(nodeId);
communitiesSet.put(neo4jId, community);
return true;
});
assertThat(communitySet).containsExactly(3L, 6L);

Function<String, Long> map = node -> communitiesSet.get(idFunction.of(node));

//community 1
assertThat(map.apply("a0"))
.isEqualTo(map.apply("a2"))
.isEqualTo(map.apply("a3"))
.isEqualTo(map.apply("a4"))
.isNotEqualTo(map.apply("a1"));

//community 2
assertThat(map.apply("a1"))
.isEqualTo(map.apply("a5"))
.isEqualTo(map.apply("a6"))
.isEqualTo(map.apply("a7"));

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.catalog.GraphProjectProc;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.extension.Neo4jGraph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -73,6 +77,10 @@ class LeidenWriteProcTest extends BaseProcTest {
" (a5)-[:R {weight: 1.0}]->(a7)," +
" (a6)-[:R {weight: 1.0}]->(a7)";

@Inject
IdFunction idFunction;


@BeforeEach
void setUp() throws Exception {
registerProcedures(
Expand Down Expand Up @@ -100,16 +108,36 @@ void write(String procedureName) {


var writeGraph = GraphStoreCatalog.get(getUsername(), DatabaseId.of(db.databaseName()), "writeGraph").graphStore().getUnion();

var communities = writeGraph.nodeProperties("communityId");
var communitySet = new HashSet<Long>();

HashMap<Long, Long> communitiesSet = new HashMap<>();

writeGraph.forEachNode(nodeId -> {
communitySet.add(communities.longValue(nodeId));
var community = communities.longValue(nodeId);
var neo4jId = writeGraph.toOriginalNodeId(nodeId);
communitiesSet.put(neo4jId, community);
return true;
});
assertThat(communitySet).containsExactly(3L, 6L);

Function<String, Long> map = node -> communitiesSet.get(idFunction.of(node));

//community 1
assertThat(map.apply("a0"))
.isEqualTo(map.apply("a2"))
.isEqualTo(map.apply("a3"))
.isEqualTo(map.apply("a4"))
.isNotEqualTo(map.apply("a1"));

//community 2
assertThat(map.apply("a1"))
.isEqualTo(map.apply("a5"))
.isEqualTo(map.apply("a6"))
.isEqualTo(map.apply("a7"));

}


@Test
void shouldWriteWithConsecutiveIds() {
var query = "CALL gds.leiden.write('leiden', { writeProperty: 'communityId', consecutiveIds: true })";
Expand Down

0 comments on commit e0b80d2

Please sign in to comment.