Commit: Add an explanation of experiments and other things

tmke8 committed May 16, 2024
1 parent e3fe893 commit 732708e

Showing 12 changed files with 163 additions and 8 deletions.
95 changes: 95 additions & 0 deletions README.md
@@ -46,12 +46,17 @@ python main.py --help
```

It should look like this:

<details>
<summary> Help output (quite long) </summary>

```
== Configuration groups ==
Compose your configuration from those groups (group=option)
dm: celeba, cmnist
dm/celeba: male_blond, male_smiling
experiment: cmnist_cnn, good_model
model: cnn, fcn, single_linear_layer
@@ -78,6 +83,11 @@ model:
  dropout_prob: 0.0
  final_bias: true
  input_norm: false
opt:
  lr: 5.0e-05
  weight_decay: 0.0
  optimizer_cls: torch.optim.AdamW
  optimizer_kwargs: null
wandb:
  name: null
  dir: ./local_logging
@@ -94,6 +104,8 @@ seed: 42
gpu: 0
```

</details>

We can see that `seed` and `gpu` are "top-level" config values, but there are also a lot of values in subconfigs. For example, we can change the `num_hidden` value in the `model` config:
```bash
python main.py model.num_hidden=2
@@ -149,6 +161,12 @@ You can also set
python main.py dm=celeba/male_smiling
```

### W&B config
To enable W&B logging, set `wandb.mode` to "online" or "offline" (the default is "disabled"):
```bash
python main.py wandb.mode=online
```

## Multiruns
Often, we want to run the same experiment with only slightly different configurations. This can be done with the `--multirun` flag. For example, to run the code with seeds 42 and 43, we can do:
```bash
@@ -186,13 +204,90 @@ By default, Hydra simply runs the code with the different config values sequentially
python main.py --multirun seed=42,43 hydra/launcher=slurm/kyiv
```
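For reference, such a launcher config might look roughly like the sketch below (a hypothetical `conf/hydra/launcher/slurm/kyiv.yaml`, assuming the `hydra-submitit-launcher` plugin; the field values here are made up):
```yaml
---
defaults:
  # Extend the submitit plugin's base Slurm launcher config.
  - submitit_slurm

timeout_min: 240
partition: gpu
gpus_per_node: 1
mem_gb: 16
```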

Hydra doesn't limit you to sweeping over a single parameter; you can combine sweeps over several. For example, to run the code with seeds 42 and 43 and `model.num_hidden` set to 1, 2, and 3, you can do:
```bash
python main.py --multirun seed=42,43 model.hidden_dim=10 model.num_hidden=1,2,3 gpu=0
```
This will start 6 jobs: 2 seeds × 3 `model.num_hidden` values.

## Experiment configs
Hydra commands can get very long when we specify many config values. To make this easier, we can define "experiment configs" in the `conf/experiment/` directory.

For example, let's say that after a lot of experimentation, we found that the following config values work well together:
```bash
python main.py model=fcn model.num_hidden=2 model.hidden_dim=10 model.norm=LN model.input_norm=true model.activation=SELU opt.lr=0.001 opt.weight_decay=0.001
```

We can then create a new experiment config file at `conf/experiment/good_model.yaml` with the following content:
```yaml
# @package _global_
---
defaults:
  - override /model: fcn

model:
  num_hidden: 2
  hidden_dim: 10
  norm: LN
  input_norm: true
  activation: SELU

opt:
  lr: 0.001
  weight_decay: 0.001
```
Note that the comment `# @package _global_` is required. By default, Hydra associates a config file in the `conf/experiment/` directory with an `experiment` entry in the main configuration – but no such entry exists! The `@package _global_` directive therefore tells Hydra to merge the content of the file into the *top level* of the main config instead.
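To illustrate the difference: without the directive, the non-defaults content of the file would be merged under an `experiment` key, roughly like this sketch – which the `Config` schema would then reject:
```yaml
experiment:
  model:
    num_hidden: 2
    # ... and so on
```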

We can then run the code with these config values:
```bash
python main.py +experiment=good_model
```

You can still override values:
```bash
python main.py +experiment=good_model model.hidden_dim=20
```

## How to make your code easily configurable with Hydra
In our experience, the best way to structure your code so that it is easy to configure with Hydra is the "builder" (or "factory") pattern.

This means you have a dataclass that holds all the configuration values for a particular component of your code (e.g. the model, the data, the optimiser). This class then has a `build()` or `init()` method that takes the additional arguments which only become available at runtime (e.g. the input size of the model, or the number of classes in the dataset) and instantiates the component.

For example, in this code base, we have the `ModelFactory` class in `src/model.py`:
```python
@dataclass(eq=False)
class ModelFactory(ABC):
    """Interface for model factories."""

    @abstractmethod
    def build(self, in_dim: int, *, out_dim: int) -> nn.Module:
        raise NotImplementedError()
```
And when you add a model to the code base, you subclass `ModelFactory` and implement the `build()` method:
```python
@dataclass(eq=False, kw_only=True)
class SimpleCNNFactory(ModelFactory):
    """Factory for a very simple CNN."""

    kernel_size: int = 5
    pool_stride: int = 2
    activation: Activation = Activation.RELU

    @override
    def build(self, in_dim: int, *, out_dim: int) -> nn.Sequential:
        ...
```
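To give a feel for how a factory is used, here is a hypothetical call site (the real wiring lives in `src/run.py`; the concrete argument values are made up):
```python
factory = SimpleCNNFactory(kernel_size=3)
# in_dim and out_dim only become known at runtime,
# e.g. from the data module:
model = factory.build(in_dim=3, out_dim=10)
```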


## Structure of the code

- `main.py`: The main entry point of the code. It sets up Hydra and then calls the main `run()` function.
- `src/`
- `run.py`: Contains the main `Config` class that is used to define valid config values. It also contains the `run()` function that is called by `main.py`.
- `data.py`: Contains the `DataModule` class that is used to load the data.
- `model.py`: Contains the `ModelFactory` class that is used to create the model.
- `optimisation.py`: Contains the `OptimisationCfg` class that is used to build the optimiser that trains the model.
- `logging.py`: Contains the `WandbCfg` class that is used to set up Weights & Biases logging.
- `conf/`
- `config.yaml`: The main config file for the project. It sets the default values for `dm` and `model`.
1 change: 1 addition & 0 deletions conf/config.yaml
@@ -1,3 +1,4 @@
---
defaults:
  # The first entry here always has to be the name we have passed to `register_hydra_config`.
  - config_schema
1 change: 1 addition & 0 deletions conf/dm/celeba/male_blond.yaml
@@ -1,3 +1,4 @@
---
defaults:
  # Specify here the name of the entry in the `dm` config group you want to extend.
  - celeba
1 change: 1 addition & 0 deletions conf/dm/celeba/male_smiling.yaml
@@ -1,3 +1,4 @@
---
defaults:
  # Specify here the name of the entry in the `dm` config group you want to extend.
  - celeba
12 changes: 12 additions & 0 deletions conf/experiment/cmnist_cnn.yaml
@@ -0,0 +1,12 @@
# @package _global_
---
defaults:
  - override /dm: cmnist
  - override /model: cnn

dm:
  num_colors: 10
  val_prop: 0.2

model:
  kernel_size: 5
15 changes: 15 additions & 0 deletions conf/experiment/good_model.yaml
@@ -0,0 +1,15 @@
# @package _global_
---
defaults:
  - override /model: fcn

model:
  num_hidden: 2
  hidden_dim: 10
  norm: LN
  input_norm: true
  activation: SELU

opt:
  lr: 0.001
  weight_decay: 0.001
1 change: 1 addition & 0 deletions conf/model/single_linear_layer.yaml
@@ -1,3 +1,4 @@
---
defaults:
  - fcn

6 changes: 4 additions & 2 deletions main.py
@@ -35,7 +35,9 @@ def main(hydra_config: omegaconf.DictConfig) -> None:

if __name__ == "__main__":
    # Before calling the main function, we need to register the main `Config` class and
-   # the configuration groups.
-   # Without this, hydra doesn't know which keys and values are valid in the configuration.
+   # the configuration groups. Without this, hydra doesn't know which keys and values are valid in
+   # the configuration.
+   # Whatever you set here as `schema_name` will need to be included as the first entry in the
+   # `defaults` list in the main config yaml file (`conf/config.yaml`).
    register_hydra_config(Config, CONFIG_GROUPS, schema_name="config_schema")
    main()
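For orientation, the `main()` function partially elided by this diff presumably looks something like the following sketch; the decorator arguments and the conversion step are assumptions, not shown in the commit:
```python
import hydra
import omegaconf

from src.run import CONFIG_GROUPS, Config

@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(hydra_config: omegaconf.DictConfig) -> None:
    # Convert the untyped DictConfig into the typed `Config` dataclass...
    config = omegaconf.OmegaConf.to_object(hydra_config)
    assert isinstance(config, Config)
    # ...and run, passing a plain dict along for W&B logging.
    config.run(omegaconf.OmegaConf.to_container(hydra_config, resolve=True))
```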
1 change: 1 addition & 0 deletions pyproject.toml
@@ -23,6 +23,7 @@ pythonVersion = "3.10"
typeCheckingMode = "strict"
venvPath = "."
venv = ".venv"
reportUnknownMemberType = "none"

[tool.ruff]
line-length = 100
2 changes: 1 addition & 1 deletion src/models.py
@@ -95,7 +95,7 @@ class SimpleCNNFactory(ModelFactory):
    activation: Activation = Activation.RELU

    @override
-   def build(self, in_dim: int, *, out_dim: int) -> nn.Module:
+   def build(self, in_dim: int, *, out_dim: int) -> nn.Sequential:
        return nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=6, kernel_size=self.kernel_size),
            self.activation.init(),
24 changes: 24 additions & 0 deletions src/optimisation.py
@@ -0,0 +1,24 @@
"""Optimisation and training code."""

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

from torch.nn import Parameter
from torch.optim import Optimizer

__all__ = ["OptimisationCfg"]


@dataclass
class OptimisationCfg:
    """Config class for the optimisation."""

    lr: float = 5.0e-5
    weight_decay: float = 0.0
    optimizer_cls: str = "torch.optim.AdamW"
    optimizer_kwargs: dict[str, Any] | None = None

    def build(self, params: Iterator[tuple[str, Parameter]]) -> Optimizer:
        # Instantiate the optimizer for the given parameters.
        raise NotImplementedError()
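The commit deliberately leaves `build()` unimplemented. A minimal sketch of what an implementation could look like, resolving the optimizer class from its dotted path (this is an assumption, not part of the commit):
```python
import importlib

def build(self, params: Iterator[tuple[str, Parameter]]) -> Optimizer:
    # Resolve e.g. "torch.optim.AdamW" to the actual class.
    module_name, _, cls_name = self.optimizer_cls.rpartition(".")
    optimizer_cls = getattr(importlib.import_module(module_name), cls_name)
    return optimizer_cls(
        [param for _, param in params],  # the optimizer needs only the tensors
        lr=self.lr,
        weight_decay=self.weight_decay,
        **(self.optimizer_kwargs or {}),
    )
```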
12 changes: 7 additions & 5 deletions src/run.py
@@ -8,6 +8,7 @@
from src.datasets import CelebADataModule, ColoredMNISTDataModule, DataModule
from src.logging import WandbCfg
from src.models import FcnFactory, ModelFactory, SimpleCNNFactory
from src.optimisation import OptimisationCfg

__all__ = ["Config", "CONFIG_GROUPS"]

@@ -30,8 +31,9 @@ class Config:
    dm: DataModule
    model: ModelFactory

-   # This is a normal subconfig, for which we can specify a default,
-   # but note that in dataclasses, the default cannot be mutable, so we use `default_factory`.
+   # These are normal subconfigs, for which we can specify defaults,
+   # but note that in dataclasses, the default may not be mutable, so we use `default_factory`.
+   opt: OptimisationCfg = field(default_factory=OptimisationCfg)
    wandb: WandbCfg = field(default_factory=WandbCfg)

    # These are normal fields, for which we can specify defaults.
@@ -45,9 +47,9 @@ def run(self, config_for_logging: dict[str, Any]) -> None:
        torch.manual_seed(self.seed)

        # Initialize the logger.
-       run = self.wandb.init(config_for_logging, reinit=True)
-       if run is not None:
-           run.log({"accuracy": 0.5})
+       wandb_run = self.wandb.init(config_for_logging, reinit=True)
+       if wandb_run is not None:
+           wandb_run.log({"accuracy": 0.5})

        # Prepare the data module.
        self.dm.prepare(seed=self.seed)
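As an aside on the `default_factory` comment in the hunk above: plain dataclass instances are not hashable by default, so using one directly as a field default is rejected as a "mutable default" on Python 3.11+; `field(default_factory=...)` sidesteps this and also gives every `Config` instance its own subconfig object. A minimal sketch with hypothetical names:
```python
from dataclasses import dataclass, field

@dataclass
class Inner:
    lr: float = 5.0e-5

@dataclass
class Outer:
    # inner: Inner = Inner()  # ValueError on Python 3.11+: mutable default
    inner: Inner = field(default_factory=Inner)  # OK
```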
