diff --git a/service/models/operations/calibrate.py b/service/models/operations/calibrate.py index bb9ddd6..062c3a6 100644 --- a/service/models/operations/calibrate.py +++ b/service/models/operations/calibrate.py @@ -25,19 +25,12 @@ class CalibrateExtra(BaseModel): num_iterations: int = Field( 1000, description="Optional field for CIEMSS calibration", example=1000 ) - lr: float = Field( - 0.03, description="Optional field for CIEMSS calibration", example=0.03 - ) verbose: bool = Field( False, description="Optional field for CIEMSS calibration", example=False ) num_particles: int = Field( 1, description="Optional field for CIEMSS calibration", example=1 ) - # autoguide: pyro.infer.autoguide.AutoLowRankMultivariateNormal - solver_method: str = Field( - "dopri5", description="Optional field for CIEMSS calibration", example="dopri5" - ) class Calibrate(OperationRequest): @@ -46,6 +39,21 @@ class Calibrate(OperationRequest): dataset: Dataset = None timespan: Optional[Timespan] = None policy_intervention_id: str = Field(None, example="ba8da8d4-047d-11ee-be56") + learning_rate: float = Field( + 0.03, description="Optional field for CIEMSS calibration", example=0.03 + ) + solver_method: str = Field( + "dopri5", + description="Optional field for CIEMSS calibration", + example="dopri5", + ) + + # https://github.com/ciemss/pyciemss/blob/main/pyciemss/integration_utils/interface_checks.py + solver_step_size: float = Field( + None, + description="id from a previous calibration", + example=1.0, + ) extra: CalibrateExtra = Field( None, description="optional extra system specific arguments for advanced use cases", @@ -74,6 +82,11 @@ def hook(progress, _loss): logging.info(f"Calibration is {progress}% complete") return None + extra_options = self.extra.dict() + solver_options = {} + if self.solver_step_size is not None: + solver_options = {"step_size": self.solver_step_size} + return { "model_path_or_json": amr_path, "start_time": self.timespan.start, @@ -82,8 +95,11 @@ def hook(progress, _loss): "data_path": dataset_path, "static_parameter_interventions": static_interventions, "progress_hook": hook, + "lr": self.learning_rate, + "solver_method": self.solver_method, + "solver_options": solver_options, # "visual_options": True, - **self.extra.dict(), + **extra_options, } class Config: diff --git a/service/models/operations/simulate.py b/service/models/operations/simulate.py index 2aeeebb..45da154 100644 --- a/service/models/operations/simulate.py +++ b/service/models/operations/simulate.py @@ -2,8 +2,6 @@ from typing import ClassVar, Optional from pydantic import BaseModel, Field, Extra - - from models.base import OperationRequest, Timespan from models.converters import ( fetch_and_convert_static_interventions, @@ -28,7 +26,18 @@ class Simulate(OperationRequest): model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") timespan: Timespan = Timespan(start=0, end=90) policy_intervention_id: str = Field(None, example="ba8da8d4-047d-11ee-be56") - step_size: float = 1.0 + solver_method: str = Field( + "dopri5", + description="Optional field for CIEMSS calibration", + example="dopri5", + ) + logging_step_size: float = 1.0 + # https://github.com/ciemss/pyciemss/blob/main/pyciemss/integration_utils/interface_checks.py + solver_step_size: float = Field( + None, + description="id from a previous calibration", + example=1.0, + ) extra: SimulateExtra = Field( None, description="optional extra system specific arguments for advanced use cases", @@ -51,14 +60,20 @@ def gen_pyciemss_args(self, job_id): extra_options.pop("inferred_parameters"), job_id ) + solver_options = {} + if self.solver_step_size is not None: + solver_options = {"step_size": self.solver_step_size} + return { "model_path_or_json": amr_path, - "logging_step_size": self.step_size, + "logging_step_size": self.logging_step_size, "start_time": self.timespan.start, "end_time": self.timespan.end, "static_parameter_interventions": static_interventions, "dynamic_parameter_interventions": dynamic_interventions, "inferred_parameters": inferred_parameters, + "solver_method": self.solver_method, + "solver_options": solver_options, **extra_options, } diff --git a/tests/examples/calibrate/input/request.json b/tests/examples/calibrate/input/request.json index 3229711..e4f3d0f 100644 --- a/tests/examples/calibrate/input/request.json +++ b/tests/examples/calibrate/input/request.json @@ -14,13 +14,13 @@ "start": 0, "end": 90 }, + "learning_rate": 0.3, + "solver_method": "dopri5", "extra": { "num_samples": 100, "start_time": -1e-10, "num_iterations": 1000, - "lr": 0.03, "verbose": false, - "num_particles": 1, - "method": "dopri5" + "num_particles": 1 } }