Skip to content

Commit

Permalink
Merge pull request #1398 from AntonioCarta/ci_black
Browse files Browse the repository at this point in the history
switch to black formatting
  • Loading branch information
AntonioCarta authored Jun 28, 2023
2 parents a61ae5c + 2c8aba8 commit 4c1c2f7
Show file tree
Hide file tree
Showing 276 changed files with 4,503 additions and 6,570 deletions.
5 changes: 5 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 89
max-complexity = 18
select = B,C,E,F,W,T4,B9
10 changes: 10 additions & 0 deletions .github/workflows/black.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Lint

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
67 changes: 0 additions & 67 deletions .github/workflows/pep8.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ csvlogs/
docs/generated/
.fleet
pip-wheel-metadata
**/.DS_Store
**/.DS_Store
21 changes: 12 additions & 9 deletions avalanche/_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def shining_new_method():

def decorator(func):
if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = ""
else:
func.__doc__ += "\n\n"

Expand All @@ -59,6 +59,7 @@ def deprecated(version: float, reason: str):
alternative
:return:
"""

def decorator(func):
if inspect.isclass(func):
msg_prefix = "Call to deprecated class {name}"
Expand All @@ -69,21 +70,23 @@ def decorator(func):
msg = msg_prefix + msg_suffix

if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = ""
else:
func.__doc__ += "\n\n"

func.__doc__ += "Warning: Deprecated" + msg_suffix.format(
name=func.__name__, version=version, reason=reason)
name=func.__name__, version=version, reason=reason
)

@functools.wraps(func)
def wrapper(*args, **kwargs):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn(msg.format(name=func.__name__, version=version,
reason=reason),
category=DeprecationWarning,
stacklevel=2)
warnings.simplefilter('default', DeprecationWarning)
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
msg.format(name=func.__name__, version=version, reason=reason),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func(*args, **kwargs)

return wrapper
Expand Down
15 changes: 4 additions & 11 deletions avalanche/benchmarks/classic/ccifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,21 @@
check_vision_benchmark,
)

from avalanche.benchmarks.datasets.external_datasets.cifar import \
get_cifar10_dataset
from avalanche.benchmarks.datasets.external_datasets.cifar import get_cifar10_dataset

_default_cifar10_train_transform = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

_default_cifar10_eval_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

Expand Down Expand Up @@ -158,6 +153,4 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
check_vision_benchmark(benchmark_instance)
sys.exit(0)

__all__ = [
"SplitCIFAR10"
]
__all__ = ["SplitCIFAR10"]
35 changes: 17 additions & 18 deletions avalanche/benchmarks/classic/ccifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
check_vision_benchmark,
)

from avalanche.benchmarks.datasets.external_datasets.cifar import \
get_cifar100_dataset, get_cifar10_dataset
from avalanche.benchmarks.utils.classification_dataset import \
concat_classification_datasets_sequentially
from avalanche.benchmarks.datasets.external_datasets.cifar import (
get_cifar100_dataset,
get_cifar10_dataset,
)
from avalanche.benchmarks.utils.classification_dataset import (
concat_classification_datasets_sequentially,
)

from avalanche.benchmarks import nc_benchmark, NCScenario

Expand All @@ -31,18 +34,14 @@
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
(0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
),
transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
]
)

_default_cifar100_eval_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
),
transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
]
)

Expand Down Expand Up @@ -242,10 +241,13 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
cifar10_train, cifar10_test = get_cifar10_dataset(dataset_root_cifar10)
cifar100_train, cifar100_test = get_cifar100_dataset(dataset_root_cifar100)

cifar_10_100_train, cifar_10_100_test, _ = \
concat_classification_datasets_sequentially(
[cifar10_train, cifar100_train], [cifar10_test, cifar100_test]
)
(
cifar_10_100_train,
cifar_10_100_test,
_,
) = concat_classification_datasets_sequentially(
[cifar10_train, cifar100_train], [cifar10_test, cifar100_test]
)
# cifar10 classes
class_order = [_ for _ in range(10)]
# if a class order is defined (for cifar100) the given class labels are
Expand Down Expand Up @@ -288,7 +290,4 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
sys.exit(0)


__all__ = [
"SplitCIFAR100",
"SplitCIFAR110"
]
__all__ = ["SplitCIFAR100", "SplitCIFAR110"]
8 changes: 2 additions & 6 deletions avalanche/benchmarks/classic/ccub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,14 @@
[
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

_default_eval_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)

Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/classic/cfashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
from avalanche.benchmarks.classic.classic_benchmarks_utils import (
check_vision_benchmark,
)
from avalanche.benchmarks.datasets.external_datasets.fmnist import \
get_fmnist_dataset
from avalanche.benchmarks.datasets.external_datasets.fmnist import get_fmnist_dataset

_default_fmnist_train_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
Expand Down
4 changes: 1 addition & 3 deletions avalanche/benchmarks/classic/cimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@

from torchvision import transforms

normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

_default_train_transform = transforms.Compose(
[
Expand Down
14 changes: 3 additions & 11 deletions avalanche/benchmarks/classic/cinaturalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

from torchvision import transforms

normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

_default_train_transform = transforms.Compose(
[
Expand Down Expand Up @@ -181,16 +179,10 @@ def _get_inaturalist_dataset(dataset_root, super_categories, download):
dataset_root = default_dataset_location("inatuarlist2018")

train_set = INATURALIST2018(
str(dataset_root),
split="train",
supcats=super_categories,
download=download
str(dataset_root), split="train", supcats=super_categories, download=download
)
test_set = INATURALIST2018(
str(dataset_root),
split="val",
supcats=super_categories,
download=download
str(dataset_root), split="val", supcats=super_categories, download=download
)

return train_set, test_set
Expand Down
7 changes: 3 additions & 4 deletions avalanche/benchmarks/classic/classic_benchmarks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
DatasetStream,
)

from avalanche.benchmarks.utils.classification_dataset import \
ClassificationDataset
from avalanche.benchmarks.utils.classification_dataset import ClassificationDataset
from avalanche.benchmarks.utils.data import AvalancheDataset


def check_vision_benchmark(
benchmark_instance: DatasetScenario,
show_without_transforms=True):
benchmark_instance: DatasetScenario, show_without_transforms=True
):
from matplotlib import pyplot as plt
from torch.utils.data.dataloader import DataLoader

Expand Down
5 changes: 2 additions & 3 deletions avalanche/benchmarks/classic/clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def CLEAR(

if evaluation_protocol == "streaming":
assert seed is None, (
"Seed for train/test split is not required "
"under streaming protocol"
"Seed for train/test split is not required " "under streaming protocol"
)
train_split = "all"
test_split = "all"
Expand Down Expand Up @@ -300,7 +299,7 @@ def backward_transfer(self, matrix):
seed_list = [None]
else:
seed_list = SEED_LIST

for f in [None] + CLEAR_FEATURE_TYPES[data_name]:
t = transform if f is None else None
for seed in seed_list:
Expand Down
Loading

0 comments on commit 4c1c2f7

Please sign in to comment.