Skip to content

Commit

Permalink
migrate link prediction configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Oct 10, 2024
1 parent 4c082c0 commit c23175a
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
import org.neo4j.gds.core.ConfigKeyValidation;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Name;
Expand All @@ -40,24 +41,16 @@
import static org.neo4j.procedure.Mode.READ;

public class LinkPredictionPipelineAddTrainerMethodProcs extends BaseProc {
@Context
public GraphDataScienceProcedures facade;

@Procedure(name = "gds.beta.pipeline.linkPrediction.addLogisticRegression", mode = READ)
@Description("Add a logistic regression configuration to the parameter space of the link prediction train pipeline.")
public Stream<PipelineInfoResult> addLogisticRegression(
@Name("pipelineName") String pipelineName,
@Name(value = "config", defaultValue = "{}") Map<String, Object> logisticRegressionClassifierConfig
) {
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class);

var allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys();
ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, logisticRegressionClassifierConfig.keySet());

var tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, TrainingMethod.LogisticRegression);
pipeline.addTrainerConfig(
tunableTrainerConfig
);

return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
return facade.pipelines().linkPrediction().addLogisticRegression(pipelineName, logisticRegressionClassifierConfig);
}

@Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.procedures.pipelines;

import org.neo4j.gds.api.User;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;

import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

class Configurer {
private final PipelineRepository pipelineRepository;
private final User user;

Configurer(PipelineRepository pipelineRepository, User user) {
this.pipelineRepository = pipelineRepository;
this.user = user;
}

/**
* Some dull scaffolding
*/
<CONFIGURATION> Stream<PipelineInfoResult> configureLinkPredictionTrainingPipeline(
String pipelineNameAsString,
Supplier<CONFIGURATION> configurationSupplier,
BiConsumer<LinkPredictionTrainingPipeline, CONFIGURATION> action
) {
return configure(
pipelineNameAsString,
pipelineName -> pipelineRepository.getLinkPredictionTrainingPipeline(user, pipelineName),
configurationSupplier,
action,
PipelineInfoResult::create
);
}

/**
* Some more dull scaffolding
*/
<CONFIGURATION> Stream<NodePipelineInfoResult> configureNodeClassificationTrainingPipeline(
String pipelineNameAsString,
Supplier<CONFIGURATION> configurationSupplier,
BiConsumer<NodeClassificationTrainingPipeline, CONFIGURATION> action
) {
return configure(
pipelineNameAsString,
pipelineName -> pipelineRepository.getNodeClassificationTrainingPipeline(user, pipelineName),
configurationSupplier,
action,
NodePipelineInfoResult::create
);
}

/**
* Some dull and very generic scaffolding
*/
private <CONFIGURATION, PIPELINE, RESULT> Stream<RESULT> configure(
String pipelineNameAsString,
Function<PipelineName, PIPELINE> pipelineSupplier,
Supplier<CONFIGURATION> configurationSupplier,
BiConsumer<PIPELINE, CONFIGURATION> action,
BiFunction<PipelineName, PIPELINE, RESULT> resultRenderer
) {
var pipelineName = PipelineName.parse(pipelineNameAsString);
var pipeline = pipelineSupplier.apply(pipelineName);

var configuration = configurationSupplier.get();

action.accept(pipeline, configuration);

var r = resultRenderer.apply(pipelineName, pipeline);

return Stream.of(r);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,38 @@
*/
package org.neo4j.gds.procedures.pipelines;

import org.neo4j.gds.api.User;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;

import java.util.Map;
import java.util.stream.Stream;

public class LinkPredictionFacade {
public final class LinkPredictionFacade {
private final Configurer configurer;

private final PipelineConfigurationParser pipelineConfigurationParser;
private final PipelineApplications pipelineApplications;

LinkPredictionFacade(
PipelineConfigurationParser pipelineConfigurationParser,
private LinkPredictionFacade(
Configurer configurer, PipelineConfigurationParser pipelineConfigurationParser,
PipelineApplications pipelineApplications
) {
this.configurer = configurer;
this.pipelineConfigurationParser = pipelineConfigurationParser;
this.pipelineApplications = pipelineApplications;
}

static LinkPredictionFacade create(
User user,
PipelineConfigurationParser pipelineConfigurationParser,
PipelineApplications pipelineApplications,
PipelineRepository pipelineRepository
) {
var configurer = new Configurer(pipelineRepository, user);

return new LinkPredictionFacade(configurer, pipelineConfigurationParser, pipelineApplications);
}

public Stream<PipelineInfoResult> addFeature(
String pipelineNameAsString,
String featureType,
Expand All @@ -44,11 +61,19 @@ public Stream<PipelineInfoResult> addFeature(

var pipeline = pipelineApplications.addFeature(pipelineName, featureType, configuration);

var result = PipelineInfoResult.create(pipelineName.value, pipeline);
var result = PipelineInfoResult.create(pipelineName, pipeline);

return Stream.of(result);
}

public Stream<PipelineInfoResult> addLogisticRegression(String pipelineName, Map<String, Object> configuration) {
return configurer.configureLinkPredictionTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration),
TrainingPipeline::addTrainerConfig
);
}

public Stream<PipelineInfoResult> addNodeProperty(
String pipelineNameAsString,
String taskName,
Expand All @@ -62,7 +87,7 @@ public Stream<PipelineInfoResult> addNodeProperty(
procedureConfig
);

var result = PipelineInfoResult.create(pipelineName.value, pipeline);
var result = PipelineInfoResult.create(pipelineName, pipeline);

return Stream.of(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,29 @@
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Stream;

public final class NodeClassificationFacade {
private final Configurer configurer;
private final NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor;

private final PipelineConfigurationParser pipelineConfigurationParser;
private final PipelineApplications pipelineApplications;

NodeClassificationFacade(
Configurer configurer,
NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor,
PipelineConfigurationParser pipelineConfigurationParser,
PipelineApplications pipelineApplications
) {
this.configurer = configurer;
this.nodeClassificationPredictConfigPreProcessor = nodeClassificationPredictConfigPreProcessor;
this.pipelineConfigurationParser = pipelineConfigurationParser;
this.pipelineApplications = pipelineApplications;
Expand All @@ -53,14 +56,17 @@ static NodeClassificationFacade create(
ModelCatalog modelCatalog,
User user,
PipelineConfigurationParser pipelineConfigurationParser,
PipelineApplications pipelineApplications
PipelineApplications pipelineApplications,
PipelineRepository pipelineRepository
) {
var configurer = new Configurer(pipelineRepository, user);
var nodeClassificationPredictConfigPreProcessor = new NodeClassificationPredictConfigPreProcessor(
modelCatalog,
user
);

return new NodeClassificationFacade(
configurer,
nodeClassificationPredictConfigPreProcessor,
pipelineConfigurationParser,
pipelineApplications
Expand All @@ -71,18 +77,18 @@ public Stream<NodePipelineInfoResult> addLogisticRegression(
String pipelineName,
Map<String, Object> configuration
) {
return configure(
return configurer.configureNodeClassificationTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration),
pipelineApplications::addTrainerConfiguration
TrainingPipeline::addTrainerConfig
);
}

public Stream<NodePipelineInfoResult> addMLP(String pipelineName, Map<String, Object> configuration) {
return configure(
return configurer.configureNodeClassificationTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseMLPClassifierTrainConfig(configuration),
pipelineApplications::addTrainerConfiguration
TrainingPipeline::addTrainerConfig
);
}

Expand All @@ -105,28 +111,26 @@ public Stream<NodePipelineInfoResult> addNodeProperty(
}

public Stream<NodePipelineInfoResult> addRandomForest(String pipelineName, Map<String, Object> configuration) {
return configure(
return configurer.configureNodeClassificationTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig(
configuration),
pipelineApplications::addTrainerConfiguration
() -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig(configuration),
TrainingPipeline::addTrainerConfig
);
}

public Stream<NodePipelineInfoResult> configureAutoTuning(String pipelineName, Map<String, Object> configuration) {
return configure(
return configurer.configureNodeClassificationTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseAutoTuningConfig(configuration),
pipelineApplications::configureAutoTuning
TrainingPipeline::setAutoTuningConfig
);
}

public Stream<NodePipelineInfoResult> configureSplit(String pipelineName, Map<String, Object> configuration) {
return configure(
return configurer.configureNodeClassificationTrainingPipeline(
pipelineName,
() -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig(
configuration),
pipelineApplications::configureSplit
() -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig(configuration),
NodeClassificationTrainingPipeline::setSplitConfig
);
}

Expand Down Expand Up @@ -280,22 +284,6 @@ public Stream<MemoryEstimateResult> writeEstimate(
return Stream.of(result);
}

private <CONFIGURATION> Stream<NodePipelineInfoResult> configure(
String pipelineNameAsString,
Supplier<CONFIGURATION> configurationSupplier,
BiFunction<PipelineName, CONFIGURATION, NodeClassificationTrainingPipeline> configurationAction
) {
var pipelineName = PipelineName.parse(pipelineNameAsString);

var configuration = configurationSupplier.get();

var pipeline = configurationAction.apply(pipelineName, configuration);

var result = NodePipelineInfoResult.create(pipelineName, pipeline);

return Stream.of(result);
}

private List<NodeFeatureStep> parseNodeProperties(Object nodeProperties) {
if (nodeProperties instanceof String) return List.of(NodeFeatureStep.of((String) nodeProperties));

Expand Down
Loading

0 comments on commit c23175a

Please sign in to comment.