diff --git a/poetry.lock b/poetry.lock index 9399af71..c2247fd2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -532,30 +532,36 @@ files = [ [[package]] name = "ethicml" -version = "1.2.1" +version = "1.3.0" description = "EthicML is a library for performing and assessing algorithmic fairness. Unlike other libraries, EthicML isn't an education tool, but rather a researcher's toolkit." optional = false -python-versions = ">=3.8,<3.12" +python-versions = ">=3.10,<3.13" files = [ - {file = "ethicml-1.2.1-py3-none-any.whl", hash = "sha256:673aa63bb1dedea8968f83083fb660c7d26a2f4dabe0ad6a759d649f2f90918f"}, - {file = "ethicml-1.2.1.tar.gz", hash = "sha256:684c36a96451162d33ef462fe9ec7e528a3dcb4436bab662f454ba6b5323a9a2"}, + {file = "ethicml-1.3.0-py3-none-any.whl", hash = "sha256:5d6c4602890442968e30048a9eca46fbc030cde698cbe595710962b904e41c22"}, + {file = "ethicml-1.3.0.tar.gz", hash = "sha256:d10a56de3c629220c7ede9c2533308ed58414547331b917be2b5050134acd3b2"}, ] [package.dependencies] +filelock = "*" +jinja2 = "*" joblib = ">=1.1.0,<2.0.0" +networkx = "*" numpy = ">=1.23.2" pandas = ">=1.5.0" +pillow = ">=8.4.0" ranzen = ">=2.0.1,<2.1.0 || >2.1.0,<3.0.0" +requests = "*" scikit-learn = {version = ">=0.20.1", optional = true, markers = "extra == \"metrics\" or extra == \"models\" or extra == \"all\""} +sympy = ">=1.12,<2.0" teext = ">=0.1.3,<0.2.0" typing-extensions = ">=4.5" [package.extras] -all = ["GitPython (>=3.1.20,<4.0.0)", "cloudpickle (>=2.0.0,<3.0.0)", "fairlearn (==0.8.0)", "folktables (>=0.0.11,<0.0.12)", "gitdb2 (==4.0.2)", "matplotlib (>=3.0.2)", "pdm (>=2.4.0,<3.0.0)", "scikit-learn (>=0.20.1)", "scipy (>=1.7.2,<2.0.0)", "seaborn (>=0.9.0)", "smmap2 (==3.0.1)"] -data = ["folktables (>=0.0.11,<0.0.12)"] +all = ["GitPython (>=3.1.20,<4.0.0)", "cloudpickle (>=2.0.0,<3.0.0)", "fairlearn (==0.8.0)", "folktables (>=0.0.12)", "gitdb2 (==4.0.2)", "matplotlib (>=3.8)", "pdm (>=2.4.0,<3.0.0)", "scikit-learn (>=0.20.1)", "scipy (>=1.7.2,<2.0.0)", "seaborn (>=0.9.0)", "smmap2 (==3.0.1)"] +data = ["folktables (>=0.0.12)"] metrics = ["scikit-learn (>=0.20.1)"] models = ["GitPython (>=3.1.20,<4.0.0)", "cloudpickle (>=2.0.0,<3.0.0)", "fairlearn (==0.8.0)", "gitdb2 (==4.0.2)", "pdm (>=2.4.0,<3.0.0)", "scikit-learn (>=0.20.1)", "scipy (>=1.7.2,<2.0.0)", "smmap2 (==3.0.1)"] -plot = ["matplotlib (>=3.0.2)", "seaborn (>=0.9.0)"] +plot = ["matplotlib (>=3.8)", "seaborn (>=0.9.0)"] [[package]] name = "filelock" @@ -1962,6 +1968,21 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-type-stubs" +version = "0.1.6.dev0" +description = "A set of type stubs for popular Python packages." +optional = false +python-versions = ">=3.8" +files = [] +develop = false + +[package.source] +type = "git" +url = "https://github.com/wearepal/python-type-stubs.git" +reference = "8d5f608" +resolved_reference = "8d5f6089b412f1e013e0bf90bdd14c4a0fb11e6d" + [[package]] name = "pytorch-lightning" version = "2.1.0" @@ -2109,24 +2130,24 @@ typing-extensions = "*" [[package]] name = "ranzen" -version = "2.1.2" +version = "2.4.2" description = "A toolkit facilitating machine-learning experimentation." 
optional = false -python-versions = ">=3.8.0,<3.12" +python-versions = ">=3.10,<3.13" files = [ - {file = "ranzen-2.1.2-py3-none-any.whl", hash = "sha256:514341856155e70f7ceb03e62716f16d4e6441379e7c21fd3ad968a4867480c6"}, - {file = "ranzen-2.1.2.tar.gz", hash = "sha256:295b894aacb9c05cd6863d9f3dff5ea0e9addfc2cbf587ab4b06e922d2670fcc"}, + {file = "ranzen-2.4.2-py3-none-any.whl", hash = "sha256:ec16544b3f996f9d7d83ac4427b21653769935784541698ef689583f1a936366"}, + {file = "ranzen-2.4.2.tar.gz", hash = "sha256:295a00c3e97a7e6ec5de63b2a0fe6879ef142d516d942eb52471ed95590da67a"}, ] [package.dependencies] +numpy = ">=1.23.2,<2.0.0" typing-extensions = ">=4.5.0" [package.extras] -all = ["hydra-core (>=1.3.0,<2.0.0)", "loguru (>=0.6.0,<0.7.0)", "neoconfigen (>=2.3.3)", "numpy (>=1.23.2,<2.0.0)", "pandas (>=1.5.0,<2.0.0)", "torch (>=1.12.1)", "wandb (>=0.12,<0.14)"] -hydra = ["hydra-core (>=1.3.0,<2.0.0)", "neoconfigen (>=2.3.3)"] +all = ["attrs (>=23.1.0,<24.0.0)", "hydra-core (>=1.3.0,<2.0.0)", "loguru (>=0.6.0,<0.7.0)", "neoconfigen (>=2.4.2,<3.0.0)", "pandas (>=1.5.0,<2.0.0)", "wandb (>=0.15)"] +hydra = ["attrs (>=23.1.0,<24.0.0)", "hydra-core (>=1.3.0,<2.0.0)", "neoconfigen (>=2.4.2,<3.0.0)"] logging = ["loguru (>=0.6.0,<0.7.0)"] -torch = ["numpy (>=1.23.2,<2.0.0)", "torch (>=1.12.1)"] -wandb = ["pandas (>=1.5.0,<2.0.0)", "wandb (>=0.12,<0.14)"] +wandb = ["pandas (>=1.5.0,<2.0.0)", "wandb (>=0.15)"] [[package]] name = "regex" @@ -2366,41 +2387,45 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( [[package]] name = "scipy" -version = "1.10.1" +version = "1.12.0" description = "Fundamental algorithms for scientific computing in Python" optional = false -python-versions = "<3.12,>=3.8" -files = [ - {file = "scipy-1.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7354fd7527a4b0377ce55f286805b34e8c54b91be865bac273f527e1b839019"}, - {file = "scipy-1.10.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4b3f429188c66603a1a5c549fb414e4d3bdc2a24792e061ffbd607d3d75fd84e"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1553b5dcddd64ba9a0d95355e63fe6c3fc303a8fd77c7bc91e77d61363f7433f"}, - {file = "scipy-1.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c0ff64b06b10e35215abce517252b375e580a6125fd5fdf6421b98efbefb2d2"}, - {file = "scipy-1.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:fae8a7b898c42dffe3f7361c40d5952b6bf32d10c4569098d276b4c547905ee1"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f1564ea217e82c1bbe75ddf7285ba0709ecd503f048cb1236ae9995f64217bd"}, - {file = "scipy-1.10.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d925fa1c81b772882aa55bcc10bf88324dadb66ff85d548c71515f6689c6dac5"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35"}, - {file = "scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d"}, - {file = "scipy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35"}, - {file = "scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88"}, 
- {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1"}, - {file = "scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f"}, - {file = "scipy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9"}, - {file = "scipy-1.10.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353"}, - {file = "scipy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601"}, - {file = "scipy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea"}, - {file = "scipy-1.10.1.tar.gz", hash = "sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5"}, +python-versions = ">=3.9" +files = [ + {file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"}, + {file = "scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5f00ebaf8de24d14b8449981a2842d404152774c1a1d880c901bf454cb8e2a1"}, + {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e53958531a7c695ff66c2e7bb7b79560ffdc562e2051644c5576c39ff8efb563"}, + {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e32847e08da8d895ce09d108a494d9eb78974cf6de23063f93306a3e419960c"}, + {file = "scipy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c1020cad92772bf44b8e4cdabc1df5d87376cb219742549ef69fc9fd86282dd"}, + {file = "scipy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:75ea2a144096b5e39402e2ff53a36fecfd3b960d786b7efd3c180e29c39e53f2"}, + {file = "scipy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:408c68423f9de16cb9e602528be4ce0d6312b05001f3de61fe9ec8b1263cad08"}, + {file = "scipy-1.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5adfad5dbf0163397beb4aca679187d24aec085343755fcdbdeb32b3679f254c"}, + {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3003652496f6e7c387b1cf63f4bb720951cfa18907e998ea551e6de51a04467"}, + {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b8066bce124ee5531d12a74b617d9ac0ea59245246410e19bca549656d9a40a"}, + {file = "scipy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8bee4993817e204d761dba10dbab0774ba5a8612e57e81319ea04d84945375ba"}, + {file = "scipy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a24024d45ce9a675c1fb8494e8e5244efea1c7a09c60beb1eeb80373d0fecc70"}, + {file = "scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7e76cc48638228212c747ada851ef355c2bb5e7f939e10952bc504c11f4e372"}, + {file = "scipy-1.12.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f7ce148dffcd64ade37b2df9315541f9adad6efcaa86866ee7dd5db0c8f041c3"}, + {file = 
"scipy-1.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c39f92041f490422924dfdb782527a4abddf4707616e07b021de33467f917bc"}, + {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7ebda398f86e56178c2fa94cad15bf457a218a54a35c2a7b4490b9f9cb2676c"}, + {file = "scipy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:95e5c750d55cf518c398a8240571b0e0782c2d5a703250872f36eaf737751338"}, + {file = "scipy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e646d8571804a304e1da01040d21577685ce8e2db08ac58e543eaca063453e1c"}, + {file = "scipy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:913d6e7956c3a671de3b05ccb66b11bc293f56bfdef040583a7221d9e22a2e35"}, + {file = "scipy-1.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba1b0c7256ad75401c73e4b3cf09d1f176e9bd4248f0d3112170fb2ec4db067"}, + {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:730badef9b827b368f351eacae2e82da414e13cf8bd5051b4bdfd720271a5371"}, + {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490"}, + {file = "scipy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:196ebad3a4882081f62a5bf4aeb7326aa34b110e533aab23e4374fcccb0890dc"}, + {file = "scipy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:b360f1b6b2f742781299514e99ff560d1fe9bd1bff2712894b52abe528d1fd1e"}, + {file = "scipy-1.12.0.tar.gz", hash = "sha256:4bf5abab8a36d20193c698b0f1fc282c1d083c94723902c447e5d2f1780936a3"}, ] [package.dependencies] -numpy = ">=1.19.5,<1.27.0" +numpy = ">=1.22.4,<1.29.0" [package.extras] -dev = ["click", "doit (>=0.36.0)", "flake8", "mypy", "pycodestyle", "pydevtool", "rich-click", "typing_extensions"] -doc = ["matplotlib (>2)", "numpydoc", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] +test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "seaborn" @@ -2830,36 +2855,30 @@ opt-einsum = ["opt-einsum (>=3.3)"] [[package]] name = "torch-conduit" -version = "0.3.5" +version = "0.4.1" description = "Lightweight framework for dataloading with PyTorch and channeling the power of PyTorch Lightning" optional = false -python-versions = ">=3.8.0,<3.12" +python-versions = ">=3.10,<3.13" files = [ - {file = "torch_conduit-0.3.5-py3-none-any.whl", hash = "sha256:773ab15c9ddabf4194195edaab23794a47eebc6f2661468815ab521ded00bde3"}, - {file = "torch_conduit-0.3.5.tar.gz", hash = "sha256:5b9081761ea8d110d03200eb4da7e9972829d1bf41028b1b1192a0c3ad4b9348"}, + {file = "torch_conduit-0.4.1-py3-none-any.whl", hash = "sha256:d9e61232a7a017fd1ed7a9a0e080fb321b10ea816608518523b9c8ae9c9db44a"}, + {file = "torch_conduit-0.4.1.tar.gz", hash = "sha256:4201f46355f7397c6e0bc1b39442b83daf207a9427ed93eebf0c1cbf59f67bd2"}, ] [package.dependencies] albumentations = {version = ">=1.0.0,<2.0.0", optional = true, markers = "extra == \"image\" or 
extra == \"all\""} attrs = ">=21.2.0" -filelock = "*" -jinja2 = "*" -networkx = "*" numpy = ">=1.22.3,<2.0.0" opencv-python = {version = ">=4.5.3,<5.0.0", optional = true, markers = "extra == \"image\" or extra == \"all\""} pandas = ">=1.3.3,<3.0" -pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -ranzen = ">=2.0.0" -requests = "*" +ranzen = ">=2.4.1" scikit-learn = ">=1.2.0,<2.0.0" -sympy = "*" typing-extensions = ">=4.4.0" [package.extras] -all = ["albumentations (>=1.0.0,<2.0.0)", "ethicml[data] (>=1.2.1,<2.0.0)", "gdown (>=3.13.0,<4.0.0)", "hydra-core (>=1.1.1,<2.0.0)", "kaggle (>=1.5.12,<2.0.0)", "opencv-python (>=4.5.3,<5.0.0)", "rich (>=12.5.1,<13.0.0)", "soundfile", "sox"] +all = ["albumentations (>=1.0.0,<2.0.0)", "ethicml[data] (>=1.2.1,<2.0.0)", "folktables (>=0.0.12,<0.0.13)", "gdown (>=3.13.0,<4.0.0)", "hydra-core (>=1.1.1,<2.0.0)", "kaggle (>=1.5.12,<2.0.0)", "opencv-python (>=4.5.3,<5.0.0)", "rich (>=12.5.1,<13.0.0)", "soundfile", "sox"] audio = ["soundfile", "sox"] download = ["gdown (>=3.13.0,<4.0.0)", "kaggle (>=1.5.12,<2.0.0)"] -fair = ["ethicml[data] (>=1.2.1,<2.0.0)"] +fair = ["ethicml[data] (>=1.2.1,<2.0.0)", "folktables (>=0.0.12,<0.0.13)"] hydra = ["hydra-core (>=1.1.1,<2.0.0)"] image = ["albumentations (>=1.0.0,<2.0.0)", "opencv-python (>=4.5.3,<5.0.0)"] logging = ["rich (>=12.5.1,<13.0.0)"] @@ -3218,5 +3237,5 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" -python-versions = ">=3.10,<3.12" -content-hash = "ab002e6142da0d933f3da4085b997984ea27afb49ad5e2206fdff392b66935a6" +python-versions = ">=3.10,<3.13" +content-hash = "d9315641cde83b4dfcdabdeb8e6288e4f2f0ba8737e29ad9b11071d921c39d3a" diff --git a/pyproject.toml b/pyproject.toml index 3abeeb09..4fa176bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,13 +24,13 @@ neoconfigen = ">=2.3.3" numpy = { version = ">=1.23.2" } pandas = { version = ">=1.5.0" } pillow = "*" -python = ">=3.10,<3.12" +python = ">=3.10,<3.13" ranzen = { version = "^2.1.2" } scikit-image = ">=0.14" scikit_learn = { version = ">=0.20.1" } scipy = { version = ">=1.2.1" } seaborn = { version = ">=0.9.0" } -torch-conduit = { version = "^0.3.4", extras = ["image"] } +torch-conduit = { version = ">=0.3.4", extras = ["image"] } tqdm = { version = ">=4.31.1" } typer = "*" @@ -56,6 +56,7 @@ torchvision = ">=0.15.2" ruff = "*" types-tqdm = "*" pandas-stubs = "*" +python-type-stubs = {git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608"} [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/src/algs/adv/scorer.py b/src/algs/adv/scorer.py index 80f81996..ad3dd720 100644 --- a/src/algs/adv/scorer.py +++ b/src/algs/adv/scorer.py @@ -7,6 +7,7 @@ from conduit.data.datasets import CdtDataLoader, CdtDataset import conduit.metrics as cdtm from conduit.models.utils import prefix_keys +from conduit.types import Loss from loguru import logger from ranzen.misc import gcopy from ranzen.torch.loss import CrossEntropyLoss, ReductionType @@ -141,9 +142,8 @@ def run( score = recon_score = self.recon_score_w * 0.5 * (recon_score_tr + recon_score_dep) logger.info(f"Aggregate reconstruction score: {recon_score}") - classifier = SetClassifier( - model=disc, opt=self.opt, criterion=CrossEntropyLoss(reduction=ReductionType.mean) - ) + cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.mean) # type: ignore + classifier = SetClassifier(model=disc, opt=self.opt, criterion=cross_entropy) logger.info("Training invariance-scorer") classifier.fit( dm.train_dataloader(batch_size=self.batch_size_tr), diff --git a/src/algs/fs/dro.py 
b/src/algs/fs/dro.py index 7507ad82..05bcd4da 100644 --- a/src/algs/fs/dro.py +++ b/src/algs/fs/dro.py @@ -29,7 +29,8 @@ def __init__( reduction = str_to_enum(str_=reduction, enum=ReductionType) self.reduction = reduction if loss_fn is None: - loss_fn = CrossEntropyLoss(reduction=ReductionType.none) + cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.none) # type: ignore + loss_fn = cross_entropy else: loss_fn.reduction = ReductionType.none self.reduction = reduction @@ -37,7 +38,7 @@ def __init__( self.eta = eta @override - def forward(self, input: Tensor, *, target: Tensor) -> Tensor: # type: ignore + def forward(self, input: Tensor, *, target: Tensor) -> Tensor: sample_losses = (self.loss_fn(input, target=target) - self.eta).relu().pow(2) return reduce(sample_losses, reduction_type=self.reduction) diff --git a/src/algs/fs/lff.py b/src/algs/fs/lff.py index 56e13709..08b9dac2 100644 --- a/src/algs/fs/lff.py +++ b/src/algs/fs/lff.py @@ -5,8 +5,9 @@ from typing_extensions import Self, override from conduit.data.datasets.base import CdtDataset -from conduit.data.structures import XI, LoadedData, SampleBase, SizedDataset, TernarySample, X +from conduit.data.structures import LoadedData, SampleBase, SizedDataset, TernarySample, X from conduit.types import Indexable, IndexType +import numpy as np from ranzen.misc import gcopy from ranzen.torch import CrossEntropyLoss import torch @@ -74,7 +75,8 @@ def __add__(self, other: Self) -> Self: return copy @override - def __getitem__(self: "IndexedSample[XI]", index: IndexType) -> "IndexedSample[XI]": + def __getitem__(self, index: IndexType) -> Self: + assert isinstance(self.x, (Tensor, np.ndarray)), "x is not indexable" return gcopy( self, deep=False, x=self.x[index], y=self.y[index], s=self.s[index], idx=self.idx[index] ) @@ -101,7 +103,7 @@ def __len__(self) -> int: @dataclass(kw_only=True, repr=False, eq=False, frozen=True) class LfFClassifier(Classifier): - criterion: CrossEntropyLoss + criterion: CrossEntropyLoss # type: ignore sample_loss_ema_b: LabelEma sample_loss_ema_d: LabelEma q: float = 0.7 diff --git a/src/algs/fs/sd.py b/src/algs/fs/sd.py index be46b160..71df51c6 100644 --- a/src/algs/fs/sd.py +++ b/src/algs/fs/sd.py @@ -33,7 +33,8 @@ def __init__( if isinstance(gamma, ListConfig): gamma = list(gamma) if loss_fn is None: - loss_fn = CrossEntropyLoss(reduction=ReductionType.mean) + cross_entropy: Loss = CrossEntropyLoss(reduction=ReductionType.mean) # type: ignore + loss_fn = cross_entropy self.loss_fn = loss_fn if isinstance(lambda_, (tuple, list)): self.register_buffer("lambda_", torch.as_tensor(lambda_, dtype=torch.float)) @@ -49,7 +50,7 @@ def reduction(self) -> Union[ReductionType, str]: return self.loss_fn.reduction @reduction.setter - def reduction(self, value: Union[ReductionType, str]) -> None: + def reduction(self, value: Union[ReductionType, str]) -> None: # type: ignore self.loss_fn.reduction = value @override diff --git a/src/arch/autoencoder/resnet.py b/src/arch/autoencoder/resnet.py index 8d9a29d2..40de3ca0 100644 --- a/src/arch/autoencoder/resnet.py +++ b/src/arch/autoencoder/resnet.py @@ -22,7 +22,7 @@ def __init__( super().__init__() self.size, self.scale_factor = size, scale_factor - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: return F.interpolate(x, size=self.size, scale_factor=self.scale_factor) @@ -72,7 +72,7 @@ def __init__( self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample - def forward(self, x: Tensor) -> Tensor: # type: ignore + def 
forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -112,7 +112,7 @@ def __init__( self.downsample = downsample self.stride = stride - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -150,7 +150,7 @@ def __init__( self.bn2 = nn.BatchNorm2d(planes) self.upsample = upsample - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -189,7 +189,7 @@ def __init__( self.upsample = upsample self.scale = scale - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: identity = x out = self.conv1(x) @@ -270,7 +270,7 @@ def _make_layer( return nn.Sequential(*layers) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: x = self.conv1(x) x = self.bn1(x) x = self.relu(x) @@ -357,7 +357,7 @@ def _make_layer( return nn.Sequential(*layers) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: x = self.linear(x) # NOTE: replaced this by Linear(in_channels, 514 * 4 * 4) diff --git a/src/arch/predictors/set_transformer.py b/src/arch/predictors/set_transformer.py index 902fccf8..70c4ef68 100644 --- a/src/arch/predictors/set_transformer.py +++ b/src/arch/predictors/set_transformer.py @@ -19,7 +19,7 @@ def __init__(self, embed_dim: int, num_heads: int = 4) -> None: self.fc = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU(inplace=True)) self.mh = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads) - def forward(self, Q: Tensor, K: Tensor) -> Tensor: # type: ignore + def forward(self, Q: Tensor, K: Tensor) -> Tensor: H = self.norm1(K + self.mh(key=K, query=Q, value=K, need_weights=False)[0]) out = self.norm2(H + self.fc(H)) return out @@ -30,7 +30,7 @@ def __init__(self, dim_in: int, *, num_heads: int) -> None: super().__init__() self.mab = MultiheadAttentionBlock(embed_dim=dim_in, num_heads=num_heads) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: return self.mab(x, x) @@ -42,7 +42,7 @@ def __init__(self, dim_in: int, *, out_dim: int, num_heads: int, num_inds: int) self.mab1 = MultiheadAttentionBlock(embed_dim=out_dim, num_heads=num_heads) self.mab2 = MultiheadAttentionBlock(embed_dim=dim_in, num_heads=num_heads) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: H = self.mab1(self.inducing_points.repeat(x.size(0), 1, 1), x) return self.mab2(x, H) @@ -54,7 +54,7 @@ def __init__(self, dim: int, *, num_heads: int, num_seeds: int) -> None: nn.init.xavier_uniform_(self.seed_vectors) self.mab = MultiheadAttentionBlock(embed_dim=dim, num_heads=num_heads) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: return self.mab(self.seed_vectors.repeat(x.size(0), 1, 1), x) @@ -85,7 +85,7 @@ def __init__( ) self.predictor = nn.Linear(hidden_dim * num_inds, target_dim) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: out = self.embedder(x).unsqueeze(1) out = self.decoder(self.encoder(out)) out = out.flatten(start_dim=1).sum(0) diff --git a/src/data/splitter.py b/src/data/splitter.py index 952366ba..48c22a18 100644 --- a/src/data/splitter.py +++ b/src/data/splitter.py @@ -7,6 +7,7 @@ from typing_extensions import override from conduit.data.constants import IMAGENET_STATS +from conduit.data.datasets import 
random_split from conduit.data.datasets.utils import stratified_split from conduit.data.datasets.vision import CdtVisionDataset, ImageTform, PillowTform from loguru import logger @@ -159,8 +160,8 @@ def __post_init__(self) -> None: def split(self, dataset: D) -> TrainDepTestSplit[D]: if self.data_prop < 1: dataset = stratified_split(dataset, default_train_prop=self.data_prop).train - dep_inds, test_inds, train_inds = dataset.random_split( - props=[self.dep_prop, self.test_prop], seed=self.seed, as_indices=True + dep_inds, test_inds, train_inds = random_split( + dataset, props=[self.dep_prop, self.test_prop], seed=self.seed, as_indices=True ) train_inds = torch.as_tensor(train_inds) train_data = dataset.subset(train_inds) diff --git a/src/labelling/encoder.py b/src/labelling/encoder.py index b8232372..9a44ff4d 100644 --- a/src/labelling/encoder.py +++ b/src/labelling/encoder.py @@ -39,13 +39,13 @@ def __init__( model, self.transforms = clip.load( name=version.value, # type: ignore device="cpu", - download_root=download_root, + download_root=download_root, # type: ignore ) logger.info("Done.") self.encoder = model.visual self.out_dim = cast(int, self.encoder.output_dim) - def forward(self, x: Tensor) -> Tensor: # type: ignore + def forward(self, x: Tensor) -> Tensor: return self.encoder(x) @torch.no_grad() # pyright: ignore
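
Three of the source hunks above (`scorer.py`, `dro.py`, `sd.py`) apply the same typing workaround: under the stricter stubs pulled in via `python-type-stubs`, ranzen's concrete `CrossEntropyLoss` evidently no longer satisfies conduit's `Loss` protocol, so each site binds the instance to a `Loss`-annotated variable and confines the `# type: ignore` to that single assignment. A minimal, self-contained sketch of the pattern, assuming the ranzen/conduit APIs pinned in this lock:

```python
import torch
from conduit.types import Loss
from ranzen.torch.loss import CrossEntropyLoss, ReductionType

# Annotate once, ignore once: the mismatch between ranzen's concrete class
# and conduit's Loss protocol is silenced at a single assignment rather than
# at every consumer of the criterion.
criterion: Loss = CrossEntropyLoss(reduction=ReductionType.mean)  # type: ignore

logits = torch.randn(4, 3)
target = torch.randint(0, 3, (4,))
print(criterion(logits, target=target))
```

Keeping the ignore on the assignment rather than on the downstream `SetClassifier(...)` or `DRO` call means the rest of each call site still gets type-checked.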
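The `lff.py` hunk trades the `XI` type parameter for a `Self` return type plus a runtime guard on `x`. A sketch of that shape — `IndexedBatch` and its fields are illustrative stand-ins, not the real `IndexedSample`:

```python
from dataclasses import dataclass

import numpy as np
from ranzen.misc import gcopy
import torch
from torch import Tensor
from typing_extensions import Self


@dataclass
class IndexedBatch:
    """Illustrative stand-in for lff.py's IndexedSample."""

    x: Tensor | np.ndarray
    y: Tensor

    def __getitem__(self, index: int | slice) -> Self:
        # Runtime narrowing replaces the old XI parameterization: x is only
        # guaranteed indexable if it is a Tensor or an ndarray.
        assert isinstance(self.x, (Tensor, np.ndarray)), "x is not indexable"
        return gcopy(self, deep=False, x=self.x[index], y=self.y[index])


batch = IndexedBatch(x=torch.randn(8, 3), y=torch.zeros(8))
print(batch[2:5].y.shape)  # torch.Size([3])
```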
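Finally, the `splitter.py` hunk reflects torch-conduit 0.4 moving `random_split` from a dataset method to a free function in `conduit.data.datasets`. A sketch of the new call shape — the `props`/`seed`/`as_indices` keywords and the return order are taken from the hunk; the wrapper function itself is hypothetical:

```python
import torch
from conduit.data.datasets import CdtDataset, random_split
from torch import Tensor


def three_way_split(
    dataset: CdtDataset, dep_prop: float, test_prop: float, seed: int
) -> tuple[Tensor, Tensor, Tensor]:
    # torch-conduit 0.4: random_split is a free function taking the dataset
    # first; with as_indices=True it returns index lists for each prop, with
    # the remainder (here, the training split) last, per the hunk above.
    dep_inds, test_inds, train_inds = random_split(
        dataset, props=[dep_prop, test_prop], seed=seed, as_indices=True
    )
    return (
        torch.as_tensor(dep_inds),
        torch.as_tensor(test_inds),
        torch.as_tensor(train_inds),
    )
```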