Skip to content

Commit

Permalink
Merge pull request #8238 from vnickolov/leiden-write-facade
Browse files Browse the repository at this point in the history
Migrate Leiden write to facade
  • Loading branch information
vnickolov authored Oct 11, 2023
2 parents 64badb6 + 6f7959b commit 8ecd00b
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import org.neo4j.gds.algorithms.CommunityStatisticsSpecificFields;
import org.neo4j.gds.algorithms.K1ColoringSpecificFields;
import org.neo4j.gds.algorithms.KCoreSpecificFields;
import org.neo4j.gds.algorithms.LouvainSpecificFields;
import org.neo4j.gds.algorithms.KmeansSpecificFields;
import org.neo4j.gds.algorithms.LeidenSpecificFields;
import org.neo4j.gds.algorithms.LouvainSpecificFields;
import org.neo4j.gds.algorithms.NodePropertyWriteResult;
import org.neo4j.gds.algorithms.StandardCommunityStatisticsSpecificFields;
import org.neo4j.gds.algorithms.TriangleCountSpecificFields;
Expand All @@ -39,6 +40,8 @@
import org.neo4j.gds.kcore.KCoreDecompositionWriteConfig;
import org.neo4j.gds.kmeans.KmeansResult;
import org.neo4j.gds.kmeans.KmeansWriteConfig;
import org.neo4j.gds.leiden.LeidenResult;
import org.neo4j.gds.leiden.LeidenWriteConfig;
import org.neo4j.gds.louvain.LouvainResult;
import org.neo4j.gds.louvain.LouvainWriteConfig;
import org.neo4j.gds.result.CommunityStatistics;
Expand All @@ -53,10 +56,9 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.neo4j.gds.algorithms.community.CommunityResultCompanion.createIntermediateCommunitiesNodePropertyValues;

import static org.neo4j.gds.algorithms.community.AlgorithmRunner.runWithTiming;
import static org.neo4j.gds.algorithms.community.CommunityResultCompanion.arrayMatrixToListMatrix;
import static org.neo4j.gds.algorithms.community.CommunityResultCompanion.createIntermediateCommunitiesNodePropertyValues;

public class CommunityAlgorithmsWriteBusinessFacade {

Expand Down Expand Up @@ -275,6 +277,64 @@ public NodePropertyWriteResult<LouvainSpecificFields> louvain(
);
}

public NodePropertyWriteResult<LeidenSpecificFields> leiden(
String graphName,
LeidenWriteConfig configuration,
User user,
DatabaseId databaseId,
StatisticsComputationInstructions statisticsComputationInstructions
) {
// 1. Run the algorithm and time the execution
var intermediateResult = AlgorithmRunner.runWithTiming(
() -> communityAlgorithmsFacade.leiden(graphName, configuration, user, databaseId)
);
var algorithmResult = intermediateResult.algorithmResult;

NodePropertyValuesMapper<LeidenResult, LeidenWriteConfig> mapper = ((result, config) -> {
return config.includeIntermediateCommunities()
? createIntermediateCommunitiesNodePropertyValues(
result::getIntermediateCommunities,
result.communities().size()
)
: CommunityResultCompanion.nodePropertyValues(
config.isIncremental(),
config.writeProperty(),
config.seedProperty(),
config.consecutiveIds(),
NodePropertyValuesAdapter.adapt(result.dendrogramManager().getCurrent()),
config.minCommunitySize(),
config.concurrency(),
() -> algorithmResult.graphStore().nodeProperty(config.seedProperty())
);
});

return writeToDatabase(
algorithmResult,
configuration,
mapper,
(result -> result.communities()::get),
(result, componentCount, communitySummary) -> {
return LeidenSpecificFields.from(
result.communities().size(),
result.modularity(),
result.modularities(),
componentCount,
result.ranLevels(),
result.didConverge(),
communitySummary
);
},
statisticsComputationInstructions,
intermediateResult.computeMilliseconds,
() -> LeidenSpecificFields.EMPTY,
"LeidenWrite",
configuration.writeConcurrency(),
configuration.writeProperty(),
configuration.arrowConnectionInfo()
);
}


public NodePropertyWriteResult<KmeansSpecificFields> kmeans(
String graphName,
KmeansWriteConfig configuration,
Expand Down Expand Up @@ -495,6 +555,4 @@ <RESULT, CONFIG extends AlgoBaseConfig, ASF> NodePropertyWriteResult<ASF> writeT

}



}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
package org.neo4j.gds.leiden;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.write.NodePropertyExporterBuilder;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.MemoryEstimationExecutor;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.procedures.GraphDataScience;
import org.neo4j.gds.procedures.community.leiden.LeidenWriteResult;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
Expand All @@ -39,19 +37,17 @@
import static org.neo4j.procedure.Mode.WRITE;

public class LeidenWriteProc extends BaseProc {

@Context
public NodePropertyExporterBuilder nodePropertyExporterBuilder;
public GraphDataScience facade;

@Procedure(value = "gds.leiden.write", mode = WRITE)
@Description(DESCRIPTION)
public Stream<WriteResult> write(
public Stream<LeidenWriteResult> write(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
return new ProcedureExecutor<>(
new LeidenWriteSpec(),
executionContext()
).compute(graphName, configuration);
return facade.community().leidenWrite(graphName, configuration);
}

@Procedure(value = "gds.leiden.write.estimate", mode = READ)
Expand All @@ -60,18 +56,14 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphName,
@Name(value = "algoConfiguration") Map<String, Object> configuration
) {
return new MemoryEstimationExecutor<>(
new LeidenWriteSpec(),
executionContext(),
transactionContext()
).computeEstimate(graphName, configuration);
return facade.community().leidenEstimateWrite(graphName, configuration);
}

@Deprecated(forRemoval = true)
@Internal
@Procedure(value = "gds.beta.leiden.write", mode = WRITE, deprecatedBy = "gds.leiden.write")
@Description(DESCRIPTION)
public Stream<WriteResult> writeBeta(
public Stream<LeidenWriteResult> writeBeta(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
Expand All @@ -97,9 +89,4 @@ public Stream<MemoryEstimateResult> estimateBeta(
return estimate(graphName, configuration);
}


@Override
public ExecutionContext executionContext() {
return super.executionContext().withNodePropertyExporterBuilder(nodePropertyExporterBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.NewConfigFunction;
import org.neo4j.gds.procedures.community.leiden.LeidenWriteResult;
import org.neo4j.gds.result.AbstractResultBuilder;

import java.util.Arrays;
Expand All @@ -40,7 +41,7 @@


@GdsCallable(name = "gds.leiden.write", aliases = {"gds.beta.leiden.write"}, description = DESCRIPTION, executionMode = ExecutionMode.WRITE_NODE_PROPERTY)
public class LeidenWriteSpec implements AlgorithmSpec<Leiden, LeidenResult, LeidenWriteConfig, Stream<WriteResult>, LeidenAlgorithmFactory<LeidenWriteConfig>> {
public class LeidenWriteSpec implements AlgorithmSpec<Leiden, LeidenResult, LeidenWriteConfig, Stream<LeidenWriteResult>, LeidenAlgorithmFactory<LeidenWriteConfig>> {
@Override
public String name() {
return "LeidenWrite";
Expand All @@ -57,7 +58,7 @@ public NewConfigFunction<LeidenWriteConfig> newConfigFunction() {
}

@Override
public ComputationResultConsumer<Leiden, LeidenResult, LeidenWriteConfig, Stream<WriteResult>> computationResultConsumer() {
public ComputationResultConsumer<Leiden, LeidenResult, LeidenWriteConfig, Stream<LeidenWriteResult>> computationResultConsumer() {
return new WriteNodePropertiesComputationResultConsumer<>(
this::resultBuilder,
computationResult -> List.of(ImmutableNodeProperty.of(
Expand All @@ -72,11 +73,11 @@ public ComputationResultConsumer<Leiden, LeidenResult, LeidenWriteConfig, Stream
}

@NotNull
private AbstractResultBuilder<WriteResult> resultBuilder(
private AbstractResultBuilder<LeidenWriteResult> resultBuilder(
ComputationResult<Leiden, LeidenResult, LeidenWriteConfig> computationResult,
ExecutionContext executionContext
) {
var builder = new WriteResult.Builder(
var builder = new LeidenWriteResult.Builder(
executionContext.returnColumns(),
computationResult.config().concurrency()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.neo4j.gds.leiden.LeidenMutateConfig;
import org.neo4j.gds.leiden.LeidenStatsConfig;
import org.neo4j.gds.leiden.LeidenStreamConfig;
import org.neo4j.gds.leiden.LeidenWriteConfig;
import org.neo4j.gds.louvain.LouvainMutateConfig;
import org.neo4j.gds.louvain.LouvainStatsConfig;
import org.neo4j.gds.louvain.LouvainStreamConfig;
Expand Down Expand Up @@ -81,6 +82,7 @@
import org.neo4j.gds.procedures.community.leiden.LeidenMutateResult;
import org.neo4j.gds.procedures.community.leiden.LeidenStatsResult;
import org.neo4j.gds.procedures.community.leiden.LeidenStreamResult;
import org.neo4j.gds.procedures.community.leiden.LeidenWriteResult;
import org.neo4j.gds.procedures.community.louvain.LouvainMutateResult;
import org.neo4j.gds.procedures.community.louvain.LouvainStatsResult;
import org.neo4j.gds.procedures.community.louvain.LouvainStreamResult;
Expand Down Expand Up @@ -517,6 +519,19 @@ public Stream<LeidenStatsResult> leidenStats(
return Stream.of(LeidenComputationResultTransformer.toStatsResult(computationResult, config));
}

public Stream<LeidenWriteResult> leidenWrite(String graphName, Map<String, Object> configuration) {
var config = createConfig(configuration, LeidenWriteConfig::of);

var computationResult = writeBusinessFacade.leiden(
graphName,
config,
user,
databaseId,
ProcedureStatisticsComputationInstructions.forCommunities(procedureReturnColumns)
);

return Stream.of(LeidenComputationResultTransformer.toWriteResult(computationResult));
}

public Stream<MemoryEstimateResult> leidenEstimateStream(
Object graphNameOrConfiguration,
Expand All @@ -542,6 +557,14 @@ public Stream<MemoryEstimateResult> leidenEstimateStats(
return Stream.of(estimateBusinessFacade.leiden(graphNameOrConfiguration, config));
}

public Stream<MemoryEstimateResult> leidenEstimateWrite(
Object graphNameOrConfiguration,
Map<String, Object> algoConfiguration
) {
var config = createConfig(algoConfiguration, LeidenWriteConfig::of);
return Stream.of(estimateBusinessFacade.leiden(graphNameOrConfiguration, config));
}

public Stream<SccStreamResult> sccStream(
String graphName,
Map<String, Object> configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.neo4j.gds.algorithms.LeidenSpecificFields;
import org.neo4j.gds.algorithms.NodePropertyMutateResult;
import org.neo4j.gds.algorithms.NodePropertyWriteResult;
import org.neo4j.gds.algorithms.StatsResult;
import org.neo4j.gds.algorithms.StreamComputationResult;
import org.neo4j.gds.algorithms.community.CommunityResultCompanion;
Expand All @@ -31,6 +32,7 @@
import org.neo4j.gds.procedures.community.leiden.LeidenMutateResult;
import org.neo4j.gds.procedures.community.leiden.LeidenStatsResult;
import org.neo4j.gds.procedures.community.leiden.LeidenStreamResult;
import org.neo4j.gds.procedures.community.leiden.LeidenWriteResult;

import java.util.stream.LongStream;
import java.util.stream.Stream;
Expand Down Expand Up @@ -104,4 +106,21 @@ static LeidenStatsResult toStatsResult(
}


static LeidenWriteResult toWriteResult(NodePropertyWriteResult<LeidenSpecificFields> computationResult) {
return new LeidenWriteResult(
computationResult.algorithmSpecificFields().ranLevels(),
computationResult.algorithmSpecificFields().didConverge(),
computationResult.algorithmSpecificFields().nodeCount(),
computationResult.algorithmSpecificFields().communityCount(),
computationResult.preProcessingMillis(),
computationResult.computeMillis(),
computationResult.postProcessingMillis(),
computationResult.writeMillis(),
computationResult.nodePropertiesWritten(),
computationResult.algorithmSpecificFields().communityDistribution(),
computationResult.algorithmSpecificFields().modularities(),
computationResult.algorithmSpecificFields().modularity(),
computationResult.configuration().toMap()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,21 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.leiden;
package org.neo4j.gds.procedures.community.leiden;

import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.api.ProcedureReturnColumns;
import org.neo4j.gds.procedures.community.leiden.LeidenStatsResult;
import org.neo4j.gds.result.AbstractCommunityResultBuilder;

import java.util.List;
import java.util.Map;

public final class WriteResult extends LeidenStatsResult {
public final class LeidenWriteResult extends LeidenStatsResult {

public final long writeMillis;
public final long nodePropertiesWritten;

private WriteResult(
public LeidenWriteResult(
long ranLevels,
boolean didConverge,
long nodeCount,
Expand Down Expand Up @@ -64,41 +63,41 @@ private WriteResult(
this.nodePropertiesWritten = nodePropertiesWritten;
}

static class Builder extends AbstractCommunityResultBuilder<WriteResult> {
public static class Builder extends AbstractCommunityResultBuilder<LeidenWriteResult> {

long levels = -1;
boolean didConverge = false;

double modularity;
List<Double> modularities;

Builder(ProcedureReturnColumns returnColumns, int concurrency) {
public Builder(ProcedureReturnColumns returnColumns, int concurrency) {
super(returnColumns, concurrency);
}

Builder withLevels(long levels) {
public Builder withLevels(long levels) {
this.levels = levels;
return this;
}

Builder withDidConverge(boolean didConverge) {
public Builder withDidConverge(boolean didConverge) {
this.didConverge = didConverge;
return this;
}

Builder withModularity(double modularity) {
public Builder withModularity(double modularity) {
this.modularity = modularity;
return this;
}

Builder withModularities(List<Double> modularities) {
public Builder withModularities(List<Double> modularities) {
this.modularities = modularities;
return this;
}

@Override
protected WriteResult buildResult() {
return new WriteResult(
protected LeidenWriteResult buildResult() {
return new LeidenWriteResult(
levels,
didConverge,
nodeCount,
Expand Down

0 comments on commit 8ecd00b

Please sign in to comment.