diff --git a/external_confs/ds/acs/employment_dis_fl.yaml b/external_confs/ds/acs/employment_dis_fl.yaml new file mode 100644 index 00000000..9721ba07 --- /dev/null +++ b/external_confs/ds/acs/employment_dis_fl.yaml @@ -0,0 +1,10 @@ +--- +defaults: + - acs + +setting: employment_disability +survey_year: YEAR_2018 +states: + - FL +survey: PERSON +horizon: ONE_YEAR \ No newline at end of file diff --git a/external_confs/split/acs/employment_dis.yaml b/external_confs/split/acs/employment_dis.yaml new file mode 100644 index 00000000..800c5c8c --- /dev/null +++ b/external_confs/split/acs/employment_dis.yaml @@ -0,0 +1,8 @@ +--- +defaults: + - tabular + +seed: 0 +train_props: + 1: + 0: 0.0 \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index c2247fd2..7b9045d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -543,6 +543,7 @@ files = [ [package.dependencies] filelock = "*" +folktables = {version = ">=0.0.12", optional = true, markers = "extra == \"data\" or extra == \"all\""} jinja2 = "*" joblib = ">=1.1.0,<2.0.0" networkx = "*" @@ -578,6 +579,23 @@ files = [ docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +[[package]] +name = "folktables" +version = "0.0.12" +description = "New machine learning benchmarks from tabular datasets." +optional = false +python-versions = ">=3.7" +files = [ + {file = "folktables-0.0.12-py3-none-any.whl", hash = "sha256:979cda1900094b845ab3a8a3ae1b848f0138b780d5f8d17eeb6eb04c3c0c6617"}, + {file = "folktables-0.0.12.tar.gz", hash = "sha256:e83dde0cbcdd54c7c39b175006a50bdfc4adc351f69d4389f82aaba3eee02115"}, +] + +[package.dependencies] +numpy = "*" +pandas = "*" +requests = "*" +scikit-learn = "*" + [[package]] name = "fonttools" version = "4.39.4" @@ -2867,6 +2885,8 @@ files = [ [package.dependencies] albumentations = {version = ">=1.0.0,<2.0.0", optional = true, markers = "extra == \"image\" or extra == \"all\""} attrs = ">=21.2.0" +ethicml = {version = ">=1.2.1,<2.0.0", extras = ["data"], optional = true, markers = "extra == \"fair\" or extra == \"all\""} +folktables = {version = ">=0.0.12,<0.0.13", optional = true, markers = "extra == \"fair\" or extra == \"all\""} numpy = ">=1.22.3,<2.0.0" opencv-python = {version = ">=4.5.3,<5.0.0", optional = true, markers = "extra == \"image\" or extra == \"all\""} pandas = ">=1.3.3,<3.0" @@ -3011,13 +3031,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.5.0" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.10.0" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, - {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, + {file = "typing_extensions-4.10.0-py3-none-any.whl", hash = "sha256:69b1a937c3a517342112fb4c6df7e72fc39a38e7891a5730ed4985b5214b5475"}, + {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] [[package]] @@ -3238,4 +3258,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "d9315641cde83b4dfcdabdeb8e6288e4f2f0ba8737e29ad9b11071d921c39d3a" +content-hash = "9b3df6d30be84a67b9f788c688a29f735b5a3f2a07915dbc3b09c542ad00f1c5" diff --git a/pyproject.toml b/pyproject.toml index 36ac4623..3d8fecd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ scikit-image = ">=0.14" scikit_learn = { version = ">=0.20.1" } scipy = { version = ">=1.2.1" } seaborn = { version = ">=0.9.0" } -torch-conduit = { version = ">=0.3.4", extras = ["image"] } +torch-conduit = { version = ">=0.3.4", extras = ["image", "fair"] } +typing_extensions = ">= 4.10" tqdm = { version = ">=4.31.1" } typer = "*" @@ -56,7 +57,7 @@ torchvision = ">=0.15.2" ruff = "*" types-tqdm = "*" pandas-stubs = "*" -python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608"} +python-type-stubs = { git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608" } [build-system] requires = ["poetry-core>=1.0.0"] @@ -67,6 +68,9 @@ target-version = "py310" line-length = 100 extend-exclude = ["hydra_plugins"] +[tool.ruff.format] +quote-style = "preserve" + [tool.ruff.lint] select = ["I", "F", "E", "W", "UP"] ignore = [ diff --git a/src/arch/autoencoder/vqgan.py b/src/arch/autoencoder/vqgan.py index 7f77a85c..6c5f6a74 100644 --- a/src/arch/autoencoder/vqgan.py +++ b/src/arch/autoencoder/vqgan.py @@ -43,7 +43,7 @@ def __init__(self, in_channels: int, *, with_conv: bool) -> None: else: self.conv = None - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: if self.conv is not None: pad = (0, 1, 0, 1) x = F.pad(x, pad, mode="constant", value=0) @@ -81,7 +81,7 @@ def __init__( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: h = x h = self.norm1(h) h = F.silu(h) @@ -112,7 +112,7 @@ def __init__(self, in_channels: int): self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: h_ = x h_ = self.norm(h_) q = self.q(h_) @@ -201,7 +201,7 @@ def __init__( nn.Linear(flattened_size, out_features=latent_dim), ) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: # timestep embedding # downsampling hs = [self.conv_in(x)] @@ -288,7 +288,7 @@ def __init__( self.norm_out = Normalize(block_in) self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) - def forward(self, z: Tensor) -> Tensor: # type: ignore + def forward(self, z: Tensor) -> Tensor: # z to block_in h = self.from_latent(z) diff --git a/src/data/__init__.py b/src/data/__init__.py index f2124c5f..5ad2129a 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -1,6 +1,6 @@ from .common import * from .data_module import * -from .nico_plus_plus import * +from .factories import * from .nih import * from .splitter import * from .utils import * diff --git a/src/data/common.py b/src/data/common.py index 05733c3f..e81ba82c 100644 --- a/src/data/common.py +++ b/src/data/common.py @@ -9,7 +9,6 @@ from conduit.data import LoadedData, TernarySample, UnloadedData from conduit.data.datasets import CdtDataset from hydra.utils import to_absolute_path -from numpy import typing as npt from torch import Tensor __all__ = [ @@ -78,5 +77,5 @@ def num_samples_te(self) -> int: class DatasetFactory(ABC): @abstractmethod - def __call__(self) -> Dataset[npt.NDArray]: + def __call__(self) -> Dataset: raise NotImplementedError() diff --git a/src/data/factories.py b/src/data/factories.py new file mode 100644 index 00000000..170bd1ff --- /dev/null +++ b/src/data/factories.py @@ -0,0 +1,50 @@ +"""Dataset factories.""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional, Union +from typing_extensions import override + +from conduit.data.datasets.vision import NICOPP, NicoPPTarget +from conduit.fair.data.datasets import ( + ACSDataset, + ACSHorizon, + ACSSetting, + ACSState, + ACSSurvey, + ACSSurveyYear, +) + +from src.data.common import DatasetFactory + +__all__ = ["NICOPPCfg"] + + +@dataclass +class NICOPPCfg(DatasetFactory): + root: Union[Path, str] + target_attrs: Optional[list[NicoPPTarget]] = None + transform: Any = None # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]] + + @override + def __call__(self) -> NICOPP: + return NICOPP(root=self.root, transform=self.transform, superclasses=self.target_attrs) + + +@dataclass +class ACSCfg(DatasetFactory): + setting: ACSSetting + survey_year: ACSSurveyYear = ACSSurveyYear.YEAR_2018 + horizon: ACSHorizon = ACSHorizon.ONE_YEAR + survey: ACSSurvey = ACSSurvey.PERSON + states: list[ACSState] = field(default_factory=lambda: [ACSState.AL]) + + @override + def __call__(self) -> ACSDataset: + return ACSDataset( + setting=self.setting, + survey_year=self.survey_year, + horizon=self.horizon, + survey=self.survey, + states=self.states, + ) diff --git a/src/data/nico_plus_plus.py b/src/data/nico_plus_plus.py deleted file mode 100644 index d3673f46..00000000 --- a/src/data/nico_plus_plus.py +++ /dev/null @@ -1,24 +0,0 @@ -"""NICO Dataset.""" -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Optional, Union -from typing_extensions import override - -from conduit.data.datasets.vision import CdtVisionDataset, NICOPP, NicoPPTarget -from conduit.data.structures import TernarySample -from torch import Tensor - -from src.data.common import DatasetFactory - -__all__ = ["NICOPPCfg"] - - -@dataclass -class NICOPPCfg(DatasetFactory): - root: Union[Path, str] - target_attrs: Optional[list[NicoPPTarget]] = None - transform: Any = None # Optional[Union[Compose, BasicTransform, Callable[[Image], Any]]] - - @override - def __call__(self) -> CdtVisionDataset[TernarySample, Tensor, Tensor]: - return NICOPP(root=self.root, transform=self.transform, superclasses=self.target_attrs) diff --git a/src/data/splitter.py b/src/data/splitter.py index 48c22a18..2d1af4c0 100644 --- a/src/data/splitter.py +++ b/src/data/splitter.py @@ -1,5 +1,7 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass +from enum import Enum from pathlib import Path import platform from tempfile import TemporaryDirectory @@ -10,7 +12,10 @@ from conduit.data.datasets import random_split from conduit.data.datasets.utils import stratified_split from conduit.data.datasets.vision import CdtVisionDataset, ImageTform, PillowTform +from conduit.fair.data.datasets import ACSDataset +from conduit.transforms import MinMaxNormalize, TabularNormalize, ZScoreNormalize from loguru import logger +from ranzen import some import torch from torch import Tensor import torchvision.transforms as T @@ -24,14 +29,22 @@ "DataSplitter", "RandomSplitter", "SplitFromArtifact", + "TabularSplitter", "load_split_inds_from_artifact", "save_split_inds_as_artifact", ] @dataclass(eq=False) -class DataSplitter: - """How to split the data into train/test/dep.""" +class DataSplitter(ABC): + @abstractmethod + def __call__(self, dataset: D) -> TrainDepTestSplit[D]: + """Split the dataset into train/deployment/test.""" + + +@dataclass(eq=False) +class _VisionDataSplitter(DataSplitter): + """Common methods for transforming vision datasets.""" transductive: bool = False """Whether to include the test data in the pool of unlabelled data.""" @@ -133,7 +146,7 @@ def save_split_inds_as_artifact( @dataclass(eq=False) -class RandomSplitter(DataSplitter): +class RandomSplitter(_VisionDataSplitter): seed: int = 42 dep_prop: float = 0.4 test_prop: float = 0.2 @@ -259,7 +272,7 @@ def load_split_inds_from_artifact( @dataclass(eq=False, kw_only=True) -class SplitFromArtifact(DataSplitter): +class SplitFromArtifact(_VisionDataSplitter): artifact_name: str version: Optional[int] = None @@ -272,3 +285,41 @@ def split(self, dataset: D) -> TrainDepTestSplit[D]: dep_data = dataset.subset(splits["dep"]) test_data = dataset.subset(splits["test"]) return TrainDepTestSplit(train=train_data, deployment=dep_data, test=test_data) + + +class TabularTform(Enum): + zscore_normalize = (ZScoreNormalize,) + minmax_normalize = (MinMaxNormalize,) + + def __init__(self, tform: Callable[[], TabularNormalize]) -> None: + self.tf = tform + + +@dataclass(eq=False) +class TabularSplitter(DataSplitter): + """Split and transform tabular datasets.""" + + seed: int + train_props: dict[int, dict[int, float]] | None = None + dep_prop: float = 0.2 + test_prop: float = 0.1 + transform: TabularTform | None = TabularTform.zscore_normalize + + @override + def __call__(self, dataset: D) -> TrainDepTestSplit[D]: + if not isinstance(dataset, ACSDataset): + raise NotImplementedError("TabularSplitter only supports splitting of `ACSDataset`.") + + train, dep, test = dataset.subsampled_split( + train_props=None, + val_prop=self.dep_prop, + test_prop=self.test_prop, + seed=self.seed, + ) + if some(tf_type := self.transform): + tf = tf_type.tf() + train.fit_transform_(tf) + dep.transform_(tf) + test.transform_(tf) + + return TrainDepTestSplit(train=train, deployment=dep, test=test) diff --git a/src/relay/base.py b/src/relay/base.py index 4f9ccbe8..bb8bafb2 100644 --- a/src/relay/base.py +++ b/src/relay/base.py @@ -8,7 +8,7 @@ from src.data import DataModule, DataModuleConf, RandomSplitter, SplitFromArtifact from src.data.common import Dataset -from src.data.splitter import DataSplitter +from src.data.splitter import DataSplitter, TabularSplitter from src.labelling import Labeller from src.logging import WandbConf @@ -23,7 +23,11 @@ class BaseRelay: seed: int = 0 options: ClassVar[dict[str, dict[str, type]]] = { - "split": {"random": RandomSplitter, "artifact": SplitFromArtifact} + "split": { + "random": RandomSplitter, + "artifact": SplitFromArtifact, + "tabular": TabularSplitter, + } } def init_dm( diff --git a/src/relay/fs.py b/src/relay/fs.py index f9666cfe..98b80d5e 100644 --- a/src/relay/fs.py +++ b/src/relay/fs.py @@ -8,6 +8,7 @@ from src.arch.backbones.vision import DenseNet, ResNet, SimpleCNN from src.arch.predictors.fcn import Fcn from src.data import DatasetFactory, NICOPPCfg, NIHChestXRayDatasetCfg +from src.data.factories import ACSCfg from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg from src.labelling.pipeline import ( CentroidalLabelNoiser, @@ -44,6 +45,7 @@ class FsRelay(BaseRelay): options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | { "ds": { + "acs": ACSCfg, "cmnist": ColoredMNISTCfg, "celeba": CelebACfg, "camelyon17": Camelyon17Cfg, diff --git a/src/relay/label.py b/src/relay/label.py index 0fed1397..8e745997 100644 --- a/src/relay/label.py +++ b/src/relay/label.py @@ -3,6 +3,7 @@ from attrs import define, field from src.data.common import DatasetFactory +from src.data.factories import ACSCfg from src.data.nih import NIHChestXRayDatasetCfg from src.data.utils import resolve_device from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg @@ -31,6 +32,7 @@ class LabelRelay(BaseRelay): options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | { "ds": { + "acs": ACSCfg, "cmnist": ColoredMNISTCfg, "celeba": CelebACfg, "camelyon17": Camelyon17Cfg, diff --git a/src/relay/split.py b/src/relay/split.py index d73363e7..7c888110 100644 --- a/src/relay/split.py +++ b/src/relay/split.py @@ -4,7 +4,7 @@ from src.data import RandomSplitter from src.data.common import DatasetFactory -from src.data.nico_plus_plus import NICOPPCfg +from src.data.factories import NICOPPCfg from src.data.nih import NIHChestXRayDatasetCfg from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg from src.logging import WandbConf diff --git a/src/relay/supmatch.py b/src/relay/supmatch.py index 78831c26..51b7fc72 100644 --- a/src/relay/supmatch.py +++ b/src/relay/supmatch.py @@ -20,7 +20,7 @@ from src.arch.predictors.base import PredictorFactory from src.arch.predictors.fcn import Fcn, SetFcn from src.data.common import DatasetFactory -from src.data.nico_plus_plus import NICOPPCfg +from src.data.factories import ACSCfg, NICOPPCfg from src.data.nih import NIHChestXRayDatasetCfg from src.hydra_confs.datasets import Camelyon17Cfg, CelebACfg, ColoredMNISTCfg from src.labelling.pipeline import ( @@ -69,6 +69,7 @@ class SupMatchRelay(BaseRelay): options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | { "scorer": {"neural": NeuralScorer, "none": NullScorer}, "ds": { + "acs": ACSCfg, "cmnist": ColoredMNISTCfg, "celeba": CelebACfg, "camelyon17": Camelyon17Cfg,