diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a1e0edb..652f3e4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -8,6 +8,25 @@ before_script: - mypy --version - pytest --version +test-codestyle:3.7: + stage: test + image: python:3.7 + script: + - pwd + - ls -l + - python -c "import sys;print(sys.path)" + - flake8 --config=setup.cfg + - mypy --config-file mypy.ini + +test-units:3.7: + stage: test + image: python:3.7 + script: + - pwd + - ls -l + - python -c "import sys;print(sys.path)" + - python -m pytest . + test-codestyle: stage: test script: diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dcd9e4..bf51391 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,26 +4,30 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). -## Unreleased +## [0.1.1] - 2022/01/21 -- [...] +### Changed + +- make package compatible with python 3.7 ## [0.1.0] - 2022/01/06 -- added basic quantized layers +### Added + +- basic quantized layers - QActivation - QConv - QLinear -- added several debug layers -- added resnet, lenet -- added various quantization functions +- several debug layers +- resnet, lenet +- various quantization functions - approxsign - dorefa - sign - steheaviside - swishsign -- added support for cifar10 and mnist -- adds general training script for image classification +- support for cifar10 and mnist +- general training script for image classification - result logger for csv and tensorboard - checkpoint manager - eta estimator diff --git a/examples/image_classification/image_classification.py b/examples/image_classification/image_classification.py index eda2721..32a7ee4 100644 --- a/examples/image_classification/image_classification.py +++ b/examples/image_classification/image_classification.py @@ -58,10 +58,10 @@ def main(args: argparse.Namespace, model_args: argparse.Namespace) -> None: root_directory=args.dataset_dir, download=args.download, augmentation=augmentation_level ) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, - shuffle=True, pin_memory=True) # type: ignore - test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, - shuffle=False, pin_memory=True) # type: ignore + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, # type: ignore + num_workers=args.num_workers, shuffle=True, pin_memory=True) # type: ignore + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, # type: ignore + num_workers=args.num_workers, shuffle=False, pin_memory=True) # type: ignore model_kwargs = vars(model_args) logging.debug(f"got model args as dict: {model_kwargs}") diff --git a/examples/image_classification/utils/experiment_creator.py b/examples/image_classification/utils/experiment_creator.py index 0759bf1..e266894 100644 --- a/examples/image_classification/utils/experiment_creator.py +++ b/examples/image_classification/utils/experiment_creator.py @@ -182,7 +182,12 @@ def create( logging.debug(f"copying {file_name}...") file_path = (self.project_root / Path(file_name)).resolve() if file_path.is_dir(): - shutil.copytree(str(file_path), str(code_path / file_name), dirs_exist_ok=True) + python_version = list(map(int, sys.version.split()[0].split("."))) + if python_version[0] == 3 and python_version[1] == 7: + # python3.7 does not have the dirs_exist_ok arg + shutil.copytree(str(file_path), str(code_path / file_name)) + else: + shutil.copytree(str(file_path), str(code_path / file_name), dirs_exist_ok=True) # type: ignore else: shutil.copy(str(file_path), str(code_path / file_name)) diff --git a/setup.py b/setup.py index 10fada5..fee2eac 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def get_requirements(file_path: Union[Path, str]): "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)", "Operating System :: OS Independent", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.7", ], - python_requires='>=3.8', + python_requires='>=3.7', )