diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java index 7a8770c454..120e709063 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java @@ -23,12 +23,13 @@ import org.neo4j.gds.core.ConfigKeyValidation; import org.neo4j.gds.ml.api.TrainingMethod; import org.neo4j.gds.ml.models.automl.TunableTrainerConfig; -import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig; import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig; import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.procedures.GraphDataScienceProcedures; import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Internal; import org.neo4j.procedure.Name; @@ -40,6 +41,8 @@ import static org.neo4j.procedure.Mode.READ; public class LinkPredictionPipelineAddTrainerMethodProcs extends BaseProc { + @Context + public GraphDataScienceProcedures facade; @Procedure(name = "gds.beta.pipeline.linkPrediction.addLogisticRegression", mode = READ) @Description("Add a logistic regression configuration to the parameter space of the link prediction train pipeline.") @@ -47,17 +50,7 @@ public Stream addLogisticRegression( @Name("pipelineName") String pipelineName, @Name(value = "config", defaultValue = "{}") Map logisticRegressionClassifierConfig ) { - var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class); - - var allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys(); - ConfigKeyValidation.requireOnlyKeysFrom(allowedKeys, logisticRegressionClassifierConfig.keySet()); - - var tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, TrainingMethod.LogisticRegression); - pipeline.addTrainerConfig( - tunableTrainerConfig - ); - - return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); + return facade.pipelines().linkPrediction().addLogisticRegression(pipelineName, logisticRegressionClassifierConfig); } @Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ) diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java new file mode 100644 index 0000000000..4b012535c6 --- /dev/null +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.procedures.pipelines; + +import org.neo4j.gds.api.User; +import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; + +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Stream; + +class Configurer { + private final PipelineRepository pipelineRepository; + private final User user; + + Configurer(PipelineRepository pipelineRepository, User user) { + this.pipelineRepository = pipelineRepository; + this.user = user; + } + + /** + * Some dull scaffolding + */ + Stream configureLinkPredictionTrainingPipeline( + String pipelineNameAsString, + Supplier configurationSupplier, + BiConsumer action + ) { + return configure( + pipelineNameAsString, + pipelineName -> pipelineRepository.getLinkPredictionTrainingPipeline(user, pipelineName), + configurationSupplier, + action, + PipelineInfoResult::create + ); + } + + /** + * Some more dull scaffolding + */ + Stream configureNodeClassificationTrainingPipeline( + String pipelineNameAsString, + Supplier configurationSupplier, + BiConsumer action + ) { + return configure( + pipelineNameAsString, + pipelineName -> pipelineRepository.getNodeClassificationTrainingPipeline(user, pipelineName), + configurationSupplier, + action, + NodePipelineInfoResult::create + ); + } + + /** + * Some dull and very generic scaffolding + */ + private Stream configure( + String pipelineNameAsString, + Function pipelineSupplier, + Supplier configurationSupplier, + BiConsumer action, + BiFunction resultRenderer + ) { + var pipelineName = PipelineName.parse(pipelineNameAsString); + var pipeline = pipelineSupplier.apply(pipelineName); + + var configuration = configurationSupplier.get(); + + action.accept(pipeline, configuration); + + var r = resultRenderer.apply(pipelineName, pipeline); + + return Stream.of(r); + } +} diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java index b20592b813..c9775541fa 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java @@ -19,21 +19,38 @@ */ package org.neo4j.gds.procedures.pipelines; +import org.neo4j.gds.api.User; +import org.neo4j.gds.ml.pipeline.TrainingPipeline; + import java.util.Map; import java.util.stream.Stream; -public class LinkPredictionFacade { +public final class LinkPredictionFacade { + private final Configurer configurer; + private final PipelineConfigurationParser pipelineConfigurationParser; private final PipelineApplications pipelineApplications; - LinkPredictionFacade( - PipelineConfigurationParser pipelineConfigurationParser, + private LinkPredictionFacade( + Configurer configurer, PipelineConfigurationParser pipelineConfigurationParser, PipelineApplications pipelineApplications ) { + this.configurer = configurer; this.pipelineConfigurationParser = pipelineConfigurationParser; this.pipelineApplications = pipelineApplications; } + static LinkPredictionFacade create( + User user, + PipelineConfigurationParser pipelineConfigurationParser, + PipelineApplications pipelineApplications, + PipelineRepository pipelineRepository + ) { + var configurer = new Configurer(pipelineRepository, user); + + return new LinkPredictionFacade(configurer, pipelineConfigurationParser, pipelineApplications); + } + public Stream addFeature( String pipelineNameAsString, String featureType, @@ -44,11 +61,19 @@ public Stream addFeature( var pipeline = pipelineApplications.addFeature(pipelineName, featureType, configuration); - var result = PipelineInfoResult.create(pipelineName.value, pipeline); + var result = PipelineInfoResult.create(pipelineName, pipeline); return Stream.of(result); } + public Stream addLogisticRegression(String pipelineName, Map configuration) { + return configurer.configureLinkPredictionTrainingPipeline( + pipelineName, + () -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration), + TrainingPipeline::addTrainerConfig + ); + } + public Stream addNodeProperty( String pipelineNameAsString, String taskName, @@ -62,7 +87,7 @@ public Stream addNodeProperty( procedureConfig ); - var result = PipelineInfoResult.create(pipelineName.value, pipeline); + var result = PipelineInfoResult.create(pipelineName, pipeline); return Stream.of(result); } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacade.java index a19c2b395c..160dfcb5dd 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacade.java @@ -24,26 +24,29 @@ import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult; import org.neo4j.gds.core.model.ModelCatalog; import org.neo4j.gds.ml.pipeline.PipelineCompanion; +import org.neo4j.gds.ml.pipeline.TrainingPipeline; import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.function.BiFunction; -import java.util.function.Supplier; import java.util.stream.Stream; public final class NodeClassificationFacade { + private final Configurer configurer; private final NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor; + private final PipelineConfigurationParser pipelineConfigurationParser; private final PipelineApplications pipelineApplications; NodeClassificationFacade( + Configurer configurer, NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor, PipelineConfigurationParser pipelineConfigurationParser, PipelineApplications pipelineApplications ) { + this.configurer = configurer; this.nodeClassificationPredictConfigPreProcessor = nodeClassificationPredictConfigPreProcessor; this.pipelineConfigurationParser = pipelineConfigurationParser; this.pipelineApplications = pipelineApplications; @@ -53,14 +56,17 @@ static NodeClassificationFacade create( ModelCatalog modelCatalog, User user, PipelineConfigurationParser pipelineConfigurationParser, - PipelineApplications pipelineApplications + PipelineApplications pipelineApplications, + PipelineRepository pipelineRepository ) { + var configurer = new Configurer(pipelineRepository, user); var nodeClassificationPredictConfigPreProcessor = new NodeClassificationPredictConfigPreProcessor( modelCatalog, user ); return new NodeClassificationFacade( + configurer, nodeClassificationPredictConfigPreProcessor, pipelineConfigurationParser, pipelineApplications @@ -71,18 +77,18 @@ public Stream addLogisticRegression( String pipelineName, Map configuration ) { - return configure( + return configurer.configureNodeClassificationTrainingPipeline( pipelineName, () -> pipelineConfigurationParser.parseLogisticRegressionTrainerConfig(configuration), - pipelineApplications::addTrainerConfiguration + TrainingPipeline::addTrainerConfig ); } public Stream addMLP(String pipelineName, Map configuration) { - return configure( + return configurer.configureNodeClassificationTrainingPipeline( pipelineName, () -> pipelineConfigurationParser.parseMLPClassifierTrainConfig(configuration), - pipelineApplications::addTrainerConfiguration + TrainingPipeline::addTrainerConfig ); } @@ -105,28 +111,26 @@ public Stream addNodeProperty( } public Stream addRandomForest(String pipelineName, Map configuration) { - return configure( + return configurer.configureNodeClassificationTrainingPipeline( pipelineName, - () -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig( - configuration), - pipelineApplications::addTrainerConfiguration + () -> pipelineConfigurationParser.parseRandomForestClassifierTrainerConfig(configuration), + TrainingPipeline::addTrainerConfig ); } public Stream configureAutoTuning(String pipelineName, Map configuration) { - return configure( + return configurer.configureNodeClassificationTrainingPipeline( pipelineName, () -> pipelineConfigurationParser.parseAutoTuningConfig(configuration), - pipelineApplications::configureAutoTuning + TrainingPipeline::setAutoTuningConfig ); } public Stream configureSplit(String pipelineName, Map configuration) { - return configure( + return configurer.configureNodeClassificationTrainingPipeline( pipelineName, - () -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig( - configuration), - pipelineApplications::configureSplit + () -> pipelineConfigurationParser.parseNodePropertyPredictionSplitConfig(configuration), + NodeClassificationTrainingPipeline::setSplitConfig ); } @@ -280,22 +284,6 @@ public Stream writeEstimate( return Stream.of(result); } - private Stream configure( - String pipelineNameAsString, - Supplier configurationSupplier, - BiFunction configurationAction - ) { - var pipelineName = PipelineName.parse(pipelineNameAsString); - - var configuration = configurationSupplier.get(); - - var pipeline = configurationAction.apply(pipelineName, configuration); - - var result = NodePipelineInfoResult.create(pipelineName, pipeline); - - return Stream.of(result); - } - private List parseNodeProperties(Object nodeProperties) { if (nodeProperties instanceof String) return List.of(NodeFeatureStep.of((String) nodeProperties)); diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java index 4f27e82a3c..5ad16dd4d2 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java @@ -47,8 +47,6 @@ import org.neo4j.gds.mem.MemoryEstimations; import org.neo4j.gds.metrics.Metrics; import org.neo4j.gds.ml.models.Classifier; -import org.neo4j.gds.ml.models.automl.TunableTrainerConfig; -import org.neo4j.gds.ml.pipeline.AutoTuningConfig; import org.neo4j.gds.ml.pipeline.NodePropertyStepFactory; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.TrainingPipeline; @@ -56,7 +54,6 @@ import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfiguration; import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep; -import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationModelResult; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo; @@ -69,7 +66,6 @@ import java.util.Map; import java.util.Optional; -import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -274,23 +270,6 @@ NodeClassificationTrainingPipeline addNodePropertyToNodeClassificationPipeline( return pipeline; } - NodeClassificationTrainingPipeline addTrainerConfiguration( - PipelineName pipelineName, - TunableTrainerConfig configuration - ) { - return configure(pipelineName, pipeline -> pipeline.addTrainerConfig(configuration)); - } - - NodeClassificationTrainingPipeline configureAutoTuning(PipelineName pipelineName, AutoTuningConfig configuration) { - return configure(pipelineName, pipeline -> pipeline.setAutoTuningConfig(configuration)); - } - - NodeClassificationTrainingPipeline configureSplit( - PipelineName pipelineName, NodePropertyPredictionSplitConfig configuration - ) { - return configure(pipelineName, pipeline -> pipeline.setSplitConfig(configuration)); - } - NodeClassificationTrainingPipeline createNodeClassificationTrainingPipeline(PipelineName pipelineName) { return pipelineRepository.createNodeClassificationTrainingPipeline(user, pipelineName); } @@ -454,20 +433,6 @@ NodeClassificationTrainingPipeline selectFeatures( return pipeline; } - private NodeClassificationTrainingPipeline configure( - PipelineName pipelineName, - Consumer configurationAction - ) { - var pipeline = pipelineRepository.getNodeClassificationTrainingPipeline( - user, - pipelineName - ); - - configurationAction.accept(pipeline); - - return pipeline; - } - private NodeClassificationPredictComputation constructPredictComputation( NodeClassificationPredictPipelineBaseConfig configuration, Label label diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java index 5c9d3f66b8..ebf0981570 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java @@ -51,6 +51,14 @@ private PipelineInfoResult( this.parameterSpace = parameterSpace; } + public static PipelineInfoResult create(PipelineName pipelineName, LinkPredictionTrainingPipeline pipeline) { + return create(pipelineName.value, pipeline); + } + + /** + * @deprecated migrate to the other one + */ + @Deprecated public static PipelineInfoResult create(String pipelineName, LinkPredictionTrainingPipeline pipeline) { var nodePropertySteps = pipeline .nodePropertySteps() diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineName.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineName.java index db494cdad7..e8991c05df 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineName.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineName.java @@ -22,7 +22,7 @@ import org.neo4j.gds.core.CypherMapAccess; import org.neo4j.gds.core.StringIdentifierValidations; -final class PipelineName { +public final class PipelineName { final String value; private PipelineName(String value) { diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java index 795ed3f69f..c632b0351c 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java @@ -112,13 +112,19 @@ public static PipelinesProcedureFacade create( algorithmProcessingTemplate ); - var linkPredictionFacade = new LinkPredictionFacade(pipelineConfigurationParser, pipelineApplications); + var linkPredictionFacade = LinkPredictionFacade.create( + user, + pipelineConfigurationParser, + pipelineApplications, + pipelineRepository + ); var nodeClassificationFacade = NodeClassificationFacade.create( modelCatalog, user, pipelineConfigurationParser, - pipelineApplications + pipelineApplications, + pipelineRepository ); return new PipelinesProcedureFacade( diff --git a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacadeTest.java b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacadeTest.java index 740c0dd2ba..a83de4dd59 100644 --- a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacadeTest.java +++ b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/NodeClassificationFacadeTest.java @@ -61,7 +61,7 @@ void createPipeline() { ); var facade2 = new PipelinesProcedureFacade(applications, null, null); - var facade = new NodeClassificationFacade(null, null, applications); + var facade = new NodeClassificationFacade(null, null, null, applications); var result = facade.createPipeline("myPipeline").findAny().orElseThrow(); @@ -104,7 +104,7 @@ void shouldNotCreatePipelineWhenOneExists() { null, null ); - var facade = new NodeClassificationFacade(null, null, applications); + var facade = new NodeClassificationFacade(null, null, null, applications); assertThatIllegalStateException() .isThrownBy(() -> facade.createPipeline("myPipeline")) @@ -113,11 +113,7 @@ void shouldNotCreatePipelineWhenOneExists() { @Test void shouldNotCreatePipelineWithInvalidName() { - var facade = new NodeClassificationFacade( - null, - null, - null - ); + var facade = new NodeClassificationFacade(null, null, null, null); assertThatIllegalArgumentException() .isThrownBy(() -> facade.createPipeline(" blanks!"))