diff --git a/conf/hydra/sweeper/fcn_params.yaml b/conf/hydra/sweeper/fcn_params.yaml new file mode 100644 index 0000000..0ca251a --- /dev/null +++ b/conf/hydra/sweeper/fcn_params.yaml @@ -0,0 +1,18 @@ +--- +# +# usage: hydra/sweeper=fcn_params +# +defaults: + - optuna + +sampler: + warn_independent_sampling: true + +study_name: fcn_params +n_trials: 25 +n_jobs: 3 +direction: + - maximize + +params: + opt.lr: tag(log, interval(1.e-5, 1.e-3)) diff --git a/pyproject.toml b/pyproject.toml index e6ffa2a..303cba1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "torch~=2.3.0", "wandb~=0.17.0", "hydra-submitit-launcher~=1.2.0", + "hydra-optuna-sweeper~=1.2.0", ] readme = "README.md" requires-python = ">= 3.10" diff --git a/requirements-dev.lock b/requirements-dev.lock index 2be8ebd..6cf5b67 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -7,9 +7,15 @@ # all-features: false # with-sources: false +alembic==1.13.1 + # via optuna antlr4-python3-runtime==4.9.3 # via hydra-core # via omegaconf +attrs==23.2.0 + # via cmd2 +autopage==0.5.2 + # via cliff certifi==2022.12.7 # via requests # via sentry-sdk @@ -17,8 +23,16 @@ charset-normalizer==2.1.1 # via requests click==8.1.7 # via wandb +cliff==4.7.0 + # via optuna cloudpickle==3.0.0 # via submitit +cmaes==0.10.0 + # via optuna +cmd2==2.4.3 + # via cliff +colorlog==6.8.2 + # via optuna docker-pycreds==0.4.0 # via wandb filelock==3.13.1 @@ -30,21 +44,31 @@ gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via wandb +greenlet==3.0.3 + # via sqlalchemy hydra-core==1.3.2 + # via hydra-optuna-sweeper # via hydra-submitit-launcher +hydra-optuna-sweeper==1.2.0 hydra-submitit-launcher==1.2.0 idna==3.4 # via requests jinja2==3.1.3 # via torch +mako==1.3.5 + # via alembic markupsafe==2.1.5 # via jinja2 + # via mako mpmath==1.3.0 # via sympy networkx==3.2.1 # via torch numpy==1.26.3 + # via cmaes + # via optuna # via ranzen + # via scipy nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12 # via nvidia-cusolver-cu12 @@ -75,21 +99,34 @@ nvidia-nvtx-cu12==12.1.105 # via torch omegaconf==2.3.0 # via hydra-core +optuna==2.10.1 + # via hydra-optuna-sweeper packaging==22.0 # via hydra-core + # via optuna +pbr==6.0.0 + # via stevedore platformdirs==4.2.2 # via wandb +prettytable==3.10.0 + # via cliff protobuf==4.25.3 # via wandb psutil==5.9.8 # via wandb +pyperclip==1.8.2 + # via cmd2 pyyaml==6.0.1 + # via cliff # via omegaconf + # via optuna # via wandb ranzen==2.5.1 requests==2.28.1 # via wandb ruff==0.4.4 +scipy==1.13.0 + # via optuna sentry-sdk==2.2.0 # via wandb setproctitle==1.3.3 @@ -100,18 +137,30 @@ six==1.16.0 # via docker-pycreds smmap==5.0.1 # via gitdb +sqlalchemy==2.0.30 + # via alembic + # via optuna +stevedore==5.2.0 + # via cliff submitit==1.5.1 # via hydra-submitit-launcher sympy==1.12 # via torch torch==2.3.0 +tqdm==4.64.1 + # via optuna triton==2.3.0 # via torch typing-extensions==4.9.0 + # via alembic # via ranzen + # via sqlalchemy # via submitit # via torch urllib3==1.26.13 # via requests # via sentry-sdk wandb==0.17.0 +wcwidth==0.2.13 + # via cmd2 + # via prettytable diff --git a/requirements.lock b/requirements.lock index 488ef0d..63ec4b1 100644 --- a/requirements.lock +++ b/requirements.lock @@ -7,9 +7,15 @@ # all-features: false # with-sources: false +alembic==1.13.1 + # via optuna antlr4-python3-runtime==4.9.3 # via hydra-core # via omegaconf +attrs==23.2.0 + # via cmd2 +autopage==0.5.2 + # via cliff certifi==2022.12.7 # via requests # via sentry-sdk @@ -17,8 +23,16 @@ charset-normalizer==2.1.1 # via requests click==8.1.7 # via wandb +cliff==4.7.0 + # via optuna cloudpickle==3.0.0 # via submitit +cmaes==0.10.0 + # via optuna +cmd2==2.4.3 + # via cliff +colorlog==6.8.2 + # via optuna docker-pycreds==0.4.0 # via wandb filelock==3.13.1 @@ -30,21 +44,31 @@ gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via wandb +greenlet==3.0.3 + # via sqlalchemy hydra-core==1.3.2 + # via hydra-optuna-sweeper # via hydra-submitit-launcher +hydra-optuna-sweeper==1.2.0 hydra-submitit-launcher==1.2.0 idna==3.4 # via requests jinja2==3.1.3 # via torch +mako==1.3.5 + # via alembic markupsafe==2.1.5 # via jinja2 + # via mako mpmath==1.3.0 # via sympy networkx==3.2.1 # via torch numpy==1.26.3 + # via cmaes + # via optuna # via ranzen + # via scipy nvidia-cublas-cu12==12.1.3.1 # via nvidia-cudnn-cu12 # via nvidia-cusolver-cu12 @@ -75,20 +99,33 @@ nvidia-nvtx-cu12==12.1.105 # via torch omegaconf==2.3.0 # via hydra-core +optuna==2.10.1 + # via hydra-optuna-sweeper packaging==22.0 # via hydra-core + # via optuna +pbr==6.0.0 + # via stevedore platformdirs==4.2.2 # via wandb +prettytable==3.10.0 + # via cliff protobuf==4.25.3 # via wandb psutil==5.9.8 # via wandb +pyperclip==1.8.2 + # via cmd2 pyyaml==6.0.1 + # via cliff # via omegaconf + # via optuna # via wandb ranzen==2.5.1 requests==2.28.1 # via wandb +scipy==1.13.0 + # via optuna sentry-sdk==2.2.0 # via wandb setproctitle==1.3.3 @@ -99,18 +136,30 @@ six==1.16.0 # via docker-pycreds smmap==5.0.1 # via gitdb +sqlalchemy==2.0.30 + # via alembic + # via optuna +stevedore==5.2.0 + # via cliff submitit==1.5.1 # via hydra-submitit-launcher sympy==1.12 # via torch torch==2.3.0 +tqdm==4.64.1 + # via optuna triton==2.3.0 # via torch typing-extensions==4.9.0 + # via alembic # via ranzen + # via sqlalchemy # via submitit # via torch urllib3==1.26.13 # via requests # via sentry-sdk wandb==0.17.0 +wcwidth==0.2.13 + # via cmd2 + # via prettytable