
Commit

Merge branch 'main' into release
tmke8 committed Jul 5, 2021
2 parents a844217 + c69a46e commit 1b6f95a
Showing 9 changed files with 202 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/_sources/ethicml.data.rst.txt
@@ -1,6 +1,7 @@
********
Datasets
********

.. automodule:: ethicml.data

.. contents::
118 changes: 114 additions & 4 deletions docs/ethicml.data.html

Large diffs are not rendered by default.

46 changes: 40 additions & 6 deletions docs/genindex.html
@@ -200,6 +200,8 @@ <h2 id="_">_</h2>
<li><a href="ethicml.vision.html#ethicml.vision.DatasetWrapper.__len__">(DatasetWrapper method)</a>
</li>
<li><a href="ethicml.utility.html#ethicml.utility.DataTuple.__len__">(DataTuple method)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.__len__">(FeatureSplit method)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.util.LabelGroup.__len__">(LabelGroup method)</a>
</li>
@@ -347,8 +349,14 @@ <h2 id="C">C</h2>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.calders.Calders">Calders (class in ethicml.algorithms.preprocess.calders)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.vision_data.celeba.celeba">celeba() (in module ethicml.data.vision_data.celeba)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.vision_data.celeba.CELEBA_BASE_FOLDER">CELEBA_BASE_FOLDER (in module ethicml.data.vision_data.celeba)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.vision_data.celeba.CELEBA_FILE_LIST">CELEBA_FILE_LIST (in module ethicml.data.vision_data.celeba)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.Dataset.class_labels">class_labels (Dataset property)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.clear">clear() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.util.LabelGroup.columns">columns (LabelGroup property)</a>
</li>
@@ -365,11 +373,13 @@ <h2 id="C">C</h2>
<li><a href="ethicml.metrics.html#ethicml.metrics.confusion_matrix">confusion_matrix() (in module ethicml.metrics)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.Dataset.continuous_features">continuous_features (Dataset property)</a>
</li>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.inprocess.manual.Corels">Corels (class in ethicml.algorithms.inprocess.manual)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.copy">copy() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.inprocess.manual.Corels">Corels (class in ethicml.algorithms.inprocess.manual)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.util.LabelGroup.count">count() (LabelGroup method)</a>

<ul>
@@ -822,6 +832,8 @@ <h2 id="F">F</h2>
<li><a href="ethicml.data.html#ethicml.data.dataset.Dataset.feature_split">feature_split (Dataset property)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.Dataset.features_to_remove">features_to_remove (Dataset property)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit">FeatureSplit (class in ethicml.data.dataset)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.Dataset.filepath">filepath (Dataset property)</a>
</li>
@@ -847,6 +859,8 @@ <h2 id="F">F</h2>
<li><a href="ethicml.utility.html#ethicml.utility.TestTuple.from_npz">(TestTuple class method)</a>
</li>
</ul></li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.fromkeys">fromkeys() (FeatureSplit method)</a>
</li>
</ul></td>
</tr></table>

@@ -856,17 +870,19 @@ <h2 id="G">G</h2>
<li><a href="ethicml.data.html#ethicml.data.vision_data.genfaces.genfaces">genfaces() (in module ethicml.data.vision_data.genfaces)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.tabular_data.german.german">german() (in module ethicml.data.tabular_data.german)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.get">get() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.evaluators.html#ethicml.evaluators.cross_validator.CVResults.get_best_in_top_k">get_best_in_top_k() (CVResults method)</a>
</li>
<li><a href="ethicml.evaluators.html#ethicml.evaluators.cross_validator.CVResults.get_best_result">get_best_result() (CVResults method)</a>
</li>
<li><a href="ethicml.preprocessing.html#ethicml.preprocessing.get_biased_and_debiased_subsets">get_biased_and_debiased_subsets() (in module ethicml.preprocessing)</a>
</li>
<li><a href="ethicml.preprocessing.html#ethicml.preprocessing.get_biased_subset">get_biased_subset() (in module ethicml.preprocessing)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.preprocessing.html#ethicml.preprocessing.get_biased_subset">get_biased_subset() (in module ethicml.preprocessing)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.lookup.get_dataset_obj_by_name">get_dataset_obj_by_name() (in module ethicml.data.lookup)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.util.get_discrete_features">get_discrete_features() (in module ethicml.data.util)</a>
@@ -970,6 +986,10 @@ <h2 id="I">I</h2>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.items">items() (FeatureSplit method)</a>
</li>
</ul></td>
</tr></table>

<h2 id="K">K</h2>
@@ -980,6 +1000,8 @@ <h2 id="K">K</h2>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.inprocess.kamishima.Kamishima">Kamishima (class in ethicml.algorithms.inprocess.kamishima)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.keys">keys() (FeatureSplit method)</a>
</li>
</ul></td>
</tr></table>
@@ -1299,6 +1321,10 @@ <h2 id="P">P</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.visualisation.html#ethicml.visualisation.plot_results">plot_results() (in module ethicml.visualisation)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.pop">pop() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.popitem">popitem() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.preprocessing.html#ethicml.preprocessing.LabelBinarizer.post">post() (LabelBinarizer method)</a>
</li>
@@ -1307,11 +1333,11 @@ <h2 id="P">P</h2>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.postprocess.post_algorithm.PostAlgorithm">PostAlgorithm (class in ethicml.algorithms.postprocess.post_algorithm)</a>
</li>
<li><a href="ethicml.metrics.html#ethicml.metrics.PPV">PPV (class in ethicml.metrics)</a>
</li>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.pre_algorithm.PreAlgorithm">PreAlgorithm (class in ethicml.algorithms.preprocess.pre_algorithm)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.pre_algorithm.PreAlgorithm">PreAlgorithm (class in ethicml.algorithms.preprocess.pre_algorithm)</a>
</li>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.pre_algorithm.PreAlgorithmAsync">PreAlgorithmAsync (class in ethicml.algorithms.preprocess.pre_algorithm)</a>
</li>
<li><a href="ethicml.utility.html#ethicml.utility.Prediction">Prediction (class in ethicml.utility)</a>
@@ -1617,6 +1643,8 @@ <h2 id="S">S</h2>
<li><a href="ethicml.preprocessing.html#ethicml.preprocessing.SequentialSplit">SequentialSplit (class in ethicml.preprocessing)</a>
</li>
<li><a href="ethicml.vision.html#ethicml.vision.set_transform">set_transform() (in module ethicml.vision)</a>
</li>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.setdefault">setdefault() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.utility.html#ethicml.utility.shuffle_df">shuffle_df() (in module ethicml.utility)</a>
</li>
@@ -1698,13 +1726,19 @@ <h2 id="U">U</h2>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.update">update() (FeatureSplit method)</a>
</li>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.upsampler.Upsampler">Upsampler (class in ethicml.algorithms.preprocess.upsampler)</a>
</li>
</ul></td>
</tr></table>

<h2 id="V">V</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.data.html#ethicml.data.dataset.FeatureSplit.values">values() (FeatureSplit method)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="ethicml.algorithms.html#ethicml.algorithms.preprocess.vfae.VFAE">VFAE (class in ethicml.algorithms.preprocess.vfae)</a>
</li>
Binary file modified docs/objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/searchindex.js

Large diffs are not rendered by default.

30 changes: 21 additions & 9 deletions ethicml/data/dataset.py
@@ -2,7 +2,7 @@
from dataclasses import InitVar, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import Literal
from typing_extensions import Literal, TypedDict

import pandas as pd

@@ -16,7 +16,15 @@
label_spec_to_feature_list,
)

__all__ = ["Dataset"]
__all__ = ["Dataset", "FeatureSplit"]


class FeatureSplit(TypedDict):
"""A dictionary of the list of columns that belong to the feature groups."""

x: List[str]
s: List[str]
y: List[str]


@dataclass
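
For orientation, here is a minimal sketch of how the new FeatureSplit type behaves (the column names are made up, and this snippet is not part of the commit). At runtime a TypedDict is an ordinary dict, which is why the generated index above now lists the usual dict methods (get, items, keys, pop, update, ...) under FeatureSplit, while a static type checker constrains it to exactly the keys "x", "s" and "y":

# Illustrative sketch only; the import path follows the index entries above.
from ethicml.data.dataset import FeatureSplit

split: FeatureSplit = {"x": ["age", "hours"], "s": ["sex"], "y": ["income"]}
assert isinstance(split, dict)    # a TypedDict is a plain dict at runtime
assert split.get("s") == ["sex"]  # hence the dict methods listed in the index
# split["z"] = []                 # a checker such as mypy would reject this unknown key
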
@@ -87,7 +95,7 @@ def features_to_remove(self) -> List[str]:
return to_remove

@property
def ordered_features(self) -> Dict[str, List[str]]:
def ordered_features(self) -> FeatureSplit:
"""Return an order features dictionary.
This should have separate entries for the features, the labels and the
@@ -102,7 +110,7 @@ def ordered_features(self) -> Dict[str, List[str]]:
}

@property
def feature_split(self) -> Dict[str, List[str]]:
def feature_split(self) -> FeatureSplit:
"""Return a feature split dictionary.
This should have separate entries for the features, the labels and the sensitive attributes.
@@ -139,11 +147,12 @@ def __len__(self) -> int:
"""Number of elements in the dataset."""
return self.num_samples

def load(self, ordered: bool = False) -> DataTuple:
def load(self, ordered: bool = False, labels_as_features: bool = False) -> DataTuple:
"""Load dataset from its CSV file.
Args:
ordered: if True, return features such that discrete come first, then continuous
labels_as_features: if True, the s and y labels are included in the x features
Returns:
DataTuple with dataframes of features, labels and sensitive attributes
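
A hedged usage sketch of the extended load() signature; the german() loader and its default settings are an assumption here and are not part of this commit:

# Sketch only; import path follows the index above, defaults assumed to work out of the box.
from ethicml.data.tabular_data.german import german

dataset = german()
plain = dataset.load()                           # x contains only the feature columns
merged = dataset.load(labels_as_features=True)   # s and y columns are appended to x
assert set(merged.s.columns) <= set(merged.x.columns)  # labels now also appear in x
assert set(merged.y.columns) <= set(merged.x.columns)
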
@@ -152,15 +161,17 @@ def load(self, ordered: bool = False) -> DataTuple:
assert isinstance(dataframe, pd.DataFrame)

feature_split = self.feature_split if not ordered else self.ordered_features
feature_split_x = feature_split["x"]
if labels_as_features:
feature_split_x = feature_split["x"] + feature_split["s"] + feature_split["y"]
else:
feature_split_x = feature_split["x"]

# =========================================================================================
# Check whether we have to generate some complementary columns for binary features.
# This happens when we have for example several races: race-asian-pac-islander etc, but we
# want to have an attribute called "race_other" that summarizes them all. Now the problem
# is that this cannot be done in the before this point, because only here have we actually
# loaded the data. So, we have to do it here, with all the information we can piece
# together.
# is that this cannot be done before this point, because only here have we actually loaded
# the data. So, we have to do it here, with all the information we can piece together.

disc_feature_groups = self.discrete_feature_groups
if disc_feature_groups is not None:
@@ -192,6 +203,7 @@ def load(self, ordered: bool = False) -> DataTuple:
s_data = (s_data + 1) // 2 # map from {-1, 1} to {0, 1}
y_data = (y_data + 1) // 2 # map from {-1, 1} to {0, 1}

# the following operations remove rows if a label group is not properly one-hot encoded
s_data, s_mask = self._maybe_combine_labels(s_data, label_type="s")
if s_mask is not None:
x_data = x_data.loc[s_mask].reset_index(drop=True)
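
The comment added above refers to a row-filtering step; for readers unfamiliar with the pattern, a generic pandas sketch of the same masking idea (standalone toy data, not EthicML code):

# Boolean-mask filtering with an index reset, as used for the label-group check above.
import pandas as pd

x = pd.DataFrame({"feat": [1, 2, 3, 4]})
mask = pd.Series([True, False, True, True])  # e.g. rows whose label group is properly one-hot
x = x.loc[mask].reset_index(drop=True)       # drop the masked-out rows, renumber 0..n-1
assert list(x["feat"]) == [1, 3, 4]
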
14 changes: 8 additions & 6 deletions ethicml/data/vision_data/celeba.py
@@ -8,7 +8,7 @@
from ..dataset import Dataset
from ..util import LabelGroup, flatten_dict, label_spec_to_feature_list

__all__ = ["CelebAttrs", "celeba"]
__all__ = ["CELEBA_BASE_FOLDER", "CELEBA_FILE_LIST", "CelebAttrs", "celeba"]


CelebAttrs = Literal[
@@ -54,9 +54,10 @@
"Young",
]

_BASE_FOLDER: Final = "celeba"
CELEBA_BASE_FOLDER: Final = "celeba"
"""The data is downloaded to `download_dir` / `CELEBA_BASE_FOLDER`."""

_FILE_LIST: Final = [
CELEBA_FILE_LIST: Final = [
(
"1zmsC4yvw-e089uHXj5EdP0BSZ0AlDQRR", # File ID
"00d2c5bc6d35e252742224ab0c1e8fcb", # MD5 Hash
@@ -73,6 +74,7 @@
"list_eval_partition.txt",
),
]
"""Google drive IDs, MD5 hashes and filenames for the CelebA files."""


def celeba(
@@ -132,7 +134,7 @@ def celeba(
assert label in discrete_features
continuous_features = ["filename"]

base = root / _BASE_FOLDER
base = root / CELEBA_BASE_FOLDER
img_dir = base / "img_align_celeba"
if download:
_download(base)
@@ -161,7 +163,7 @@ def _check_integrity(base: Path) -> bool:
raise RuntimeError("Need torchvision to download data.")
from torchvision.datasets.utils import check_integrity

for (_, md5, filename) in _FILE_LIST:
for (_, md5, filename) in CELEBA_FILE_LIST:
fpath = base / filename
ext = fpath.suffix
# Allow original archive to be deleted (zip and 7z)
@@ -185,7 +187,7 @@ def _download(base: Path) -> None:
print("Files already downloaded and verified")
return

for (file_id, md5, filename) in _FILE_LIST:
for (file_id, md5, filename) in CELEBA_FILE_LIST:
download_file_from_google_drive(file_id, str(base), filename, md5)

with zipfile.ZipFile(base / "img_align_celeba.zip", "r") as fhandle:
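
Because _BASE_FOLDER and _FILE_LIST are now exported as CELEBA_BASE_FOLDER and CELEBA_FILE_LIST, downstream code can inspect the expected download layout. A hedged sketch (the download directory is illustrative):

# Sketch: list which CelebA files are already on disk; the root path is made up.
from pathlib import Path
from ethicml.data.vision_data.celeba import CELEBA_BASE_FOLDER, CELEBA_FILE_LIST

base = Path("/data/downloads") / CELEBA_BASE_FOLDER
for _file_id, _md5, filename in CELEBA_FILE_LIST:  # (Google Drive ID, MD5 hash, filename)
    print(filename, "present" if (base / filename).exists() else "missing")
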
17 changes: 17 additions & 0 deletions tests/loading_data_test.py
@@ -443,6 +443,23 @@ def test_celeba():
assert data.x["filename"].iloc[0] == "000001.jpg"


def test_celeba_all_attributes():
"""Test celeba with all attributes loaded into `data.x`."""
celeba_data, _ = em.celeba(download_dir="non-existent", check_integrity=False)
assert celeba_data is not None
data = celeba_data.load(labels_as_features=True)

assert celeba_data.name == "CelebA, s=Male, y=Smiling"

assert (202599, 41) == data.x.shape
assert (202599, 1) == data.s.shape
assert (202599, 1) == data.y.shape
assert "Male" in data.x.columns
assert "Smiling" in data.x.columns

assert data.x["filename"].iloc[0] == "000001.jpg"


def test_celeba_multi_s():
"""Test celeba w/ multi S."""
sens_spec = dict(em.simple_spec({"Age": ["Young"], "Gender": ["Male"]}))
3 changes: 0 additions & 3 deletions tests/vision_test.py
@@ -29,7 +29,6 @@ def test_label_dependent_transforms(transform):
colorizer(data, labels)


@pytest.mark.slow
def test_celeba():
"""Test celeba."""
train_set = emvi.create_celeba_dataset(
@@ -83,7 +82,6 @@ def test_celeba():
assert torch.equal(tmp_smiling, train_set.s)


@pytest.mark.slow
def test_celeba_multi_s():
"""Test celeba."""
data = emvi.create_celeba_dataset(
@@ -103,7 +101,6 @@ def test_celeba_multi_s():
assert isinstance(data, emvi.TorchImageDataset)


@pytest.mark.slow
def test_gen_faces():
"""Test gen faces."""
train_set = emvi.create_genfaces_dataset(
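
For context on the removed @pytest.mark.slow decorators: a custom marker like this is typically registered once in conftest.py and then deselected on the command line, so dropping the marker means these vision tests now run in the default selection. The registration below is an assumption about a typical setup, not something shown in this commit:

# conftest.py (illustrative): how a "slow" marker is usually registered.
def pytest_configure(config):
    config.addinivalue_line("markers", "slow: marks a test as slow to run")

# With the marker in place, such tests could be skipped via:  pytest -m "not slow"
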
