diff --git a/.circleci/config.yml b/.circleci/config.yml
index db2ad88..a0d4dcd 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,28 +7,35 @@ jobs:
     steps:
       - checkout
       - run:
-          name: Install dev dependencies
+          name: Create virtualenv
           command: |
-            python3 -m venv ~/venv
-            . ~/venv/bin/activate
-            make install
+            python -m venv /home/circleci/venv/
+            echo "source /home/circleci/venv/bin/activate" >> $BASH_ENV
+      - restore_cache:
+          keys:
+            - &cache-key python-3.9-packages-v1-{{ checksum "pyproject.toml" }}
+            - &cache-key-prefix python-3.9-packages-v1-
       - run:
-          name: Test
-          command: |
-            . ~/venv/bin/activate
-            make test
+          name: Install dev dependencies
+          command: make install
+      - save_cache:
+          key: *cache-key
+          paths:
+            - "/home/circleci/venv/"
+            - "/home/circleci/.cache/pip"
       - run:
-          name: Lint
-          command: |
-            . ~/venv/bin/activate
-            make lint
+          name: Check formatting
+          command: make format_check
+          when: always
       - run:
-          name: Mypy
-          command: |
-            . ~/venv/bin/activate
-            make mypy
+          name: Check linting
+          command: make lint_check
+          when: always
       - run:
-          name: Black
-          command: |
-            . ~/venv/bin/activate
-            make black_check
+          name: Run tests
+          command: make test
+          when: always
+      - run:
+          name: Check Python type annotations
+          command: make mypy
+          when: always
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b9ace2a..1b07370 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,22 +16,6 @@ repos:
         stages: [commit]
       - id: trailing-whitespace # Trims trailing whitespace.
 
-  - repo: https://github.com/timothycrosley/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-
-  - repo: https://github.com/ambv/black
-    rev: 23.1.0
-    hooks:
-      - id: black
-
-  - repo: https://github.com/pycqa/flake8.git
-    rev: 6.0.0
-    hooks:
-      - id: flake8
-
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: 'v1.1.1'
     hooks:
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 09325db..b168625 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -6,18 +6,17 @@ version: 2
 
 # Set the version of Python
+# and use modern dependency management
 build:
   os: ubuntu-22.04
   apt_packages:
     - libmagic1
   tools:
     python: "3.10"
+  jobs:
+    pre_build:
+      - "pip install '.[docs]'"
 
 # Build documentation in the docs/ directory
 sphinx:
-  configuration: docs/conf.py
-
-# Declare the Python requirements required to build the docs
-python:
-  install:
-    - requirements: docs/requirements.txt
+  configuration: docs/conf.py
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 295f23a..0000000
--- a/Dockerfile
+++ /dev/null
@@ -1,47 +0,0 @@
-FROM python:3.9.16-slim AS base
-
-# Create virtualenv and add to path.
-ENV VIRTUAL_ENV=/opt/venv
-RUN python3 -m venv $VIRTUAL_ENV
-ENV PATH="$VIRTUAL_ENV/bin:$PATH"
-
-# System dependencies
-RUN apt-get update && apt-get install -y build-essential && rm -rf /var/lib/apt/lists/*
-
-WORKDIR /opt/app
-
-# Install Python requirements. README.md is required as the setup.py file
-# refers to it.
-COPY . .
-RUN pip install -e .[dev,test]
-
-# Run subsequent commands as non-root user
-ENV USER=application
-RUN useradd --no-log-init --system --user-group $USER
-USER $USER
-
-# ---
-
-# Create a pytest image from the base
-FROM base as pytest
-
-# Run py.test against current dir by default but allow custom args to be passed
-# in.
-ENTRYPOINT ["py.test"]
-CMD [""]
-
-# ---
-
-# Create a isort image from the base
-FROM base as isort
-
-ENTRYPOINT ["isort"]
-CMD ["-rc"]
-
-# ---
-
-# Create a black image from the base
-FROM base as black
-
-ENTRYPOINT ["black"]
-CMD ["."]
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index 74d57a1..0000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-Sphinx==4.5.0
-myst-parser==0.18.1
-django==4.1.10
-xlrd==2.0.1
-python-magic==0.4.27
-pandas==1.5.1
-boto3==1.24.94
-botocore==1.27.94
-mypy-boto3-s3==1.26.0.post1
-openpyxl==3.0.10
-structlog==22.3
-pact-python>=1.6.0
-sphinx-rtd-theme==1.3.0
\ No newline at end of file
diff --git a/docs/xocto/development.md b/docs/xocto/development.md
index e3f89ec..92b37d0 100644
--- a/docs/xocto/development.md
+++ b/docs/xocto/development.md
@@ -1,5 +1,7 @@
 # Development
 
+This page details how to develop `xocto`.
+
 ## Installation of development environment
 
 Create and activate a Python 3.9 virtualenv then run:
@@ -8,7 +10,7 @@ Create and activate a Python 3.9 virtualenv then run:
 make install
 ```
 
-to install the package including development and testing dependencies
+to install the package including development and testing dependencies.
 
 ## Running tests
 
@@ -23,49 +25,41 @@ make test
 Use these make commands:
 
 ```sh
-make lint
-make black
-make isort
-make mypy
+make format_check  # Check formatting
+make lint_check  # Check linting
+make mypy  # Check Python type annotations
 ```
 
-Docker images for these jobs can be built with:
+## Coding conventions
 
-```sh
-make docker_images
-```
+### Don't mix code changes with version updates
 
-This creates separate images for pytest, isort and black. Each can be run like
-so:
+Code changes mixed with version updates are problematic because of this
+workflow:
 
-```sh
-docker run -v `pwd`:/opt/app xocto/pytest
-docker run -v `pwd`:/opt/app xocto/isort
-docker run -v `pwd`:/opt/app xocto/black
-```
-
-## Don't mix code changes with version updates
-
-Code changes mixed with version updates are problematic. The reason is because of this workflow:
-
-1. I write a bugfix PR that also updates the version
+1. I write a bug-fix PR that also updates the version
 2. You add a feature PR that also updates the version
 3. Everyone else mixes version changes with their code change PRs
-4. My PR is accepted, now everyone else has to update the version specified in their PR
+4. My PR is accepted, now everyone else has to update the version specified in
+   their PR
 
-This is why typically in shared projects version releases are seperated into their own pull requests.
+This is why typically in shared projects version releases are separated into
+their own pull requests.
 
 ## Publishing
 
-Before you begin, determine the release number. This follows the instructions specifiwed on [semver.org](https://semver.org/). Releases therefore use this pattern:
+### Version number
+
+First determine the version number. Follow the instructions specified on
+[semver.org](https://semver.org/) which advocates this pattern:
 
 ```
 MAJOR.MINOR.PATCH
 ```
 
-Where:
+where:
 
-- MAJOR version when you make incompatible API changes
+- MAJOR version when you make backwards-incompatible API changes
 - MINOR version when you add functionality in a backward compatible manner
 - PATCH version when you make backward compatible bug fixes
@@ -75,9 +69,10 @@ Create a pull request that:
 
 1. Adds release notes to `CHANGELOG.md`.
 
-2. Updates the `VERSION` constant in `setup.py`.
+2. Updates the `VERSION` constant in `pyproject.toml`.
 
-3. Updates the `__version__` constant in `xocto/__init__.py`, following the [semver.org](https://semver.org/) specification.
+3. Updates the `__version__` constant in `xocto/__init__.py`, following the
+   [semver.org](https://semver.org/) specification.
 
 Commit these changes in a single commit with subject matching `Bump version to v...`.
diff --git a/makefile b/makefile
index 292d8a1..b7188c2 100644
--- a/makefile
+++ b/makefile
@@ -1,41 +1,39 @@
 install:
-	pip install pip==23.1.2
-	pip install -e .[dev,test]
+	pip install pip==23.3.1
+	pip install -e '.[dev,docs]'
 
-clean:
-	@echo Cleaning workspace
-	-rm -rf dist/ *.egg-info/ build/
-	-find . -type d -name __pycache__ -delete
 
-# Static analysis
+# CI step wrappers
 
-lint:
-	make black_check ruff mypy
+ci: format_check lint_check test mypy
 
-black_check:
-	black --check --diff .
+format_check:
+	ruff format --check .
 
-ruff:
+lint_check:
 	ruff check .
 
+test:
+	py.test
+
 mypy:
 	mypy
 
-test:
-	py.test
+# Local helpers
+
+clean:
+	@echo Cleaning workspace
+	-rm -rf dist/ *.egg-info/ build/
+	-find . -type d -name __pycache__ -delete
 
 format:
 	ruff check --fix .
-	black .
-
-docker_images:
-	docker build -t xocto/pytest --target=pytest .
-	docker build -t xocto/ruff --target=ruff .
-	docker build -t xocto/black --target=black .
+	ruff format .
 
 # Releases
 
-VERSION=v$(shell python setup.py --version)
+# Extract version from pyproject.toml
+VERSION=v$(shell python -c "import importlib.metadata; print(importlib.metadata.version('xocto'))")
 
 tag:
 	@echo Tagging as $(VERSION)
"https://github.com/octoenergy/xocto/issues" + +[tool.setuptools] +packages = ["xocto", "xocto.events", "xocto.storage"] + + +# Mypy +# ---- [tool.mypy] # Specify which files to check. @@ -78,7 +156,7 @@ select = [ "I", # isort ] ignore = [ - "E501", # line too long - black takes care of this for us + "E501", # line too long ] [tool.ruff.per-file-ignores] @@ -101,3 +179,53 @@ section-order = [ "xocto", "tests", ] + + +# Pytest +# ------ + +[tool.pytest.ini_options] +# Convert some warning types into errors but ignore some others that we +# can't/won't fix right now. +# +# Note: +# - Each line is a colon-separated string. +# - The first part is what to do with the warning - error or ignore. +# - The second part is a regex that must match the start of the warning message. +# - The third part is the warning class name. +# - The fourth part is a regex that must match the module triggering the error. +# - The order matters. These rules get applied with the bottom rule first. +# Hence the rules ignoring deprecation warnings must by below the rule that converts +# DeprecationWarnings into errors. +filterwarnings = [ + "error::RuntimeWarning", + "error::DeprecationWarning", + "ignore:defusedxml.lxml:DeprecationWarning:zeep", + "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3:DeprecationWarning:(graphene|singledispatch)", + # https://github.com/ktosiek/pytest-freezegun/issues/35 + "ignore:distutils Version classes are deprecated:DeprecationWarning:pytest_freezegun", + # These deprecation warnings were added in django 4.2 to warn of removal in django 5 + "ignore:The USE_DEPRECATED_PYTZ setting:django.utils.deprecation.RemovedInDjango50Warning", + "ignore:The django.utils.timezone.utc alias is deprecated:django.utils.deprecation.RemovedInDjango50Warning", + "ignore:The is_dst argument to make_aware:django.utils.deprecation.RemovedInDjango50Warning", +] + +DJANGO_SETTINGS_MODULE = "tests.settings" + +# Test modules must have be named this way. +python_files = "test_*.py" + +# Default options when pytest is run: +# +# --verbose -> Show names of tests being run. +# --tb=short -> Use short tracebacks. +# https://docs.pytest.org/en/stable/usage.html#modifying-python-traceback-printing +# --nomigrations -> Disable Django's migrations and create the database by inspecting models instead. +# https://pytest-django.readthedocs.io/en/latest/database.html#nomigrations-disable-django-migrations +# --reuse-db -> Don't remove test database after each test run so it can be re-used next time. +# https://pytest-django.readthedocs.io/en/latest/database.html#reuse-db-reuse-the-testing-database-between-test-runs +# --color=auto -> Detect whether to print colored output. +# --capture=fd -> Capture all output written to the STDOUT and STDERR file descriptors. +# https://docs.pytest.org/en/stable/capture.html +# +addopts = "--tb=short --verbose --nomigrations --reuse-db --color=auto --capture=fd" diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 67f37da..0000000 --- a/pytest.ini +++ /dev/null @@ -1,44 +0,0 @@ -[pytest] -# Convert some warning types into errors but ignore some others that we -# can't/won't fix right now. -# -# Note: -# - Each line is a colon-separated string. -# - The first part is what to do with the warning - error or ignore. -# - The second part is a regex that must match the start of the warning message. -# - The third part is the warning class name. 
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index 67f37da..0000000
--- a/pytest.ini
+++ /dev/null
@@ -1,44 +0,0 @@
-[pytest]
-# Convert some warning types into errors but ignore some others that we
-# can't/won't fix right now.
-#
-# Note:
-# - Each line is a colon-separated string.
-# - The first part is what to do with the warning - error or ignore.
-# - The second part is a regex that must match the start of the warning message.
-# - The third part is the warning class name.
-# - The fourth part is a regex that must match the module triggering the error.
-# - The order matters. These rules get applied with the bottom rule first.
-#   Hence the rules ignoring deprecation warnings must by below the rule that converts
-#   DeprecationWarnings into errors.
-filterwarnings =
-    error::RuntimeWarning
-    error::DeprecationWarning
-    ignore:defusedxml.lxml:DeprecationWarning:zeep
-    ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3:DeprecationWarning:(graphene|singledispatch)
-    # https://github.com/ktosiek/pytest-freezegun/issues/35
-    ignore:distutils Version classes are deprecated:DeprecationWarning:pytest_freezegun
-    # These deprecation warnings were added in django 4.2 to warn of removal in django 5
-    ignore:The USE_DEPRECATED_PYTZ setting:django.utils.deprecation.RemovedInDjango50Warning
-    ignore:The django.utils.timezone.utc alias is deprecated:django.utils.deprecation.RemovedInDjango50Warning
-    ignore:The is_dst argument to make_aware:django.utils.deprecation.RemovedInDjango50Warning
-
-DJANGO_SETTINGS_MODULE=tests.settings
-
-# Test modules must have be named this way.
-python_files = test_*.py
-
-# Default options when pytest is run:
-#
-# --verbose -> Show names of tests being run.
-# --tb=short -> Use short tracebacks.
-#   https://docs.pytest.org/en/stable/usage.html#modifying-python-traceback-printing
-# --nomigrations -> Disable Django's migrations and create the database by inspecting models instead.
-#   https://pytest-django.readthedocs.io/en/latest/database.html#nomigrations-disable-django-migrations
-# --reuse-db -> Don't remove test database after each test run so it can be re-used next time.
-#   https://pytest-django.readthedocs.io/en/latest/database.html#reuse-db-reuse-the-testing-database-between-test-runs
-# --color=auto -> Detect whether to print colored output.
-# --capture=fd -> Capture all output written to the STDOUT and STDERR file descriptors.
-#   https://docs.pytest.org/en/stable/capture.html
-#
-addopts = --tb=short --verbose --nomigrations --reuse-db --color=auto --capture=fd
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 4ea536a..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,26 +0,0 @@
-[bdist_wheel]
-universal=1
-
-[flake8]
-# Ignored flake8 rules
-#
-# E203 - Colons should not have any space before them (https://www.flake8rules.com/rules/E203.html).
-#        Ignoring this can make some code more readable.
-# E501 - Line length should be less than 79 chars (https://www.flake8rules.com/rules/E501.html).
-#        We don't need flake8 to check this as black determines line formatting.
-# F541 - f-strings without any placeholders (https://flake8.pycqa.org/en/latest/user/error-codes.html).
-# W391 - There should be one, and only one, blank line at the end of each file (https://www.flake8rules.com/rules/W391.html).
-# W503 - Line breaks should occur after the binary operator to keep all variable names aligned (https://www.flake8rules.com/rules/W503.html).
-# W504 - Line breaks should occur before the binary operator to keep all operators aligned (https://www.flake8rules.com/rules/W504.html)
-ignore = E203,E501,F541,W391,W503,W504,K204,K202
-
-# Ignore unused imports (F401) in __init__ modules as these are convenience imports.
-per-file-ignores =
-    */__init__.py:F401
-
-# Enable log format extension checks.
-# See https://github.com/globality-corp/flake8-logging-format#violations-detected
-enable-extensions=G
-
-exclude =
-    .*/*.py
diff --git a/setup.py b/setup.py
deleted file mode 100644
index c887e02..0000000
--- a/setup.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from codecs import open
-from os import path
-
-from setuptools import setup
-
-
-REPO_ROOT = path.abspath(path.dirname(__file__))
-
-VERSION = "4.8.0"
-
-with open(path.join(REPO_ROOT, "README.md"), encoding="utf-8") as f:
-    long_description = f.read()
-
-setup(
-    name="xocto",
-    version=VERSION,
-    description="Kraken Technologies Python service utilities",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://github.com/octoenergy/xocto",
-    author="Kraken Technologies",
-    author_email="talent@octopus.energy",
-    license="MIT",
-    classifiers=[
-        "Development Status :: 5 - Production/Stable",
-        "Intended Audience :: Developers",
-        "Topic :: Software Development :: Build Tools",
-        "License :: OSI Approved :: MIT License",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-    ],
-    packages=["xocto", "xocto.events", "xocto.storage"],
-    package_data={"xocto": ["py.typed"]},
-    zip_safe=False,
-    install_requires=[
-        "ddtrace>=1.9.0",
-        "duckdb>=0.9.0",
-        "django>=4.0",
-        "openpyxl>=3.1.0",
-        "pact-python>=1.6.0",
-        "pandas>=1.5.3",
-        "pyarrow>=11.0.0",
-        "python-dateutil>=2.8.2",
-        "python-magic>=0.4.27",
-        "pytz",
-        "structlog>=20.2.0",
-        "xlrd>=2.0.1",
-    ],
-    extras_require={
-        "dev": [
-            "black==22.12.0",
-            "boto3==1.26.53",
-            "botocore==1.29.53",
-            "mypy-boto3-s3==1.26.0.post1",
-            "mypy==0.991",
-            "numpy==1.22.2",
-            "pre-commit>=3.2.0",
-            "pyarrow-stubs==10.0.1.6",
-            "ruff==0.0.292",
-            "twine==4.0.2",
-            "types-openpyxl==3.0.4.5",
-            "types-python-dateutil==2.8.19.6",
-            "types-pytz==2022.7.1.0",
-            "types-requests==2.28.11.8",
-            "wheel==0.38.4",
-        ],
-        "test": [
-            "ruff==0.0.292",
-            "hypothesis==6.62.1",
-            "moto[s3,sqs]==4.1",
-            "pytest-django==4.5.2",
-            "pytest-mock==3.10.0",
-            "pytest==7.2.1",
-            "time-machine==2.9.0",
-        ],
-    },
-    project_urls={
-        "Documentation": "https://xocto.readthedocs.io",
-        "Changelog": "https://github.com/octoenergy/xocto/blob/main/CHANGELOG.md",
-        "Issues": "https://github.com/octoenergy/xocto/issues",
-    },
-)
diff --git a/tests/storage/test_files.py b/tests/storage/test_files.py
index a59c68c..4deca4d 100644
--- a/tests/storage/test_files.py
+++ b/tests/storage/test_files.py
@@ -24,16 +24,17 @@ def test_returns_size(self, file, size):
 def test_files_hash_correctly():
     file1 = io.BytesIO(b"This is my first file")
     assert (
-        files.hashfile(file1) == "2f1b1b913ca382ad8f992ec6a18ecedfa2fcd8ff21b0a2227614a7bd94c23d2d"
+        files.hashfile(file1)
+        == "2f1b1b913ca382ad8f992ec6a18ecedfa2fcd8ff21b0a2227614a7bd94c23d2d"
     )
 
     file2 = io.BytesIO(b"And this is my second")
     assert (
-        files.hashfile(file2) == "8cbe3eb51eec64423d2a870da81475361fa3571402fb77810db261e1920d45b4"
+        files.hashfile(file2)
+        == "8cbe3eb51eec64423d2a870da81475361fa3571402fb77810db261e1920d45b4"
     )
 
 
 def test_convert_xlsx_file_to_csv(fixture_path):
-
     report_filename = "Daily-report-Octopus Energy-2020-04-08"
     xlsx_filepath = fixture_path(f"siteworks/agent_reports/{report_filename}.xlsx")
     csv_filepath = fixture_path(f"siteworks/agent_reports/{report_filename}.csv")
@@ -58,7 +59,6 @@ def test_convert_xlsx_file_to_csv(fixture_path):
 
 
 def test_convert_xls_file_to_csv(fixture_path):
-
     report_filename = "Daily-report-Octopus Energy-2019-08-15"
     xls_filepath = fixture_path(f"siteworks/agent_reports/{report_filename}.xls")
     csv_filepath = fixture_path(f"siteworks/agent_reports/{report_filename}.csv")
diff --git a/tests/storage/test_storage.py b/tests/storage/test_storage.py
index 7d2fbe2..291078d 100644
--- a/tests/storage/test_storage.py
+++ b/tests/storage/test_storage.py
@@ -20,7 +20,9 @@
 @pytest.fixture
 def mock_s3_bucket(mocker):
     with moto.mock_s3():
-        bucket = boto3.resource("s3", region_name="us-east-1").create_bucket(Bucket="some-bucket")
+        bucket = boto3.resource("s3", region_name="us-east-1").create_bucket(
+            Bucket="some-bucket"
+        )
         client = boto3.client("s3")
 
         mocker.patch.object(
@@ -63,11 +65,16 @@ def test_make_key_path_raises_error_when_exceeds_max_length(self):
             ),
         ],
     )
-    def test_make_key_path_with_use_date_in_key_path(self, namespace, filepath, expected):
+    def test_make_key_path_with_use_date_in_key_path(
+        self, namespace, filepath, expected
+    ):
        s3_file_store = storage.S3SubdirectoryFileStore(
             "s3://some-bucket/folder?use_date_in_key_path=1"
         )
-        assert s3_file_store.make_key_path(namespace=namespace, filepath=filepath) == expected
+        assert (
+            s3_file_store.make_key_path(namespace=namespace, filepath=filepath)
+            == expected
+        )
 
     @time_machine.travel("2021-09-10", tick=False)
     @pytest.mark.parametrize(
@@ -75,12 +82,21 @@ def test_make_key_path_with_use_date_in_key_path(self, namespace, filepath, expe
         [
             ("", "file.txt", "folder/file.txt"),
             ("namespace", "file.txt", "folder/namespace/file.txt"),
-            ("namespace/sub-namespace", "file.txt", "folder/namespace/sub-namespace/file.txt"),
+            (
+                "namespace/sub-namespace",
+                "file.txt",
+                "folder/namespace/sub-namespace/file.txt",
+            ),
         ],
     )
-    def test_make_key_path_without_use_date_in_key_path(self, namespace, filepath, expected):
+    def test_make_key_path_without_use_date_in_key_path(
+        self, namespace, filepath, expected
+    ):
         s3_file_store = storage.S3SubdirectoryFileStore("s3://some-bucket/folder")
-        assert s3_file_store.make_key_path(namespace=namespace, filepath=filepath) == expected
+        assert (
+            s3_file_store.make_key_path(namespace=namespace, filepath=filepath)
+            == expected
+        )
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
     def test_fetch_url(self, get_boto_client):
@@ -104,7 +120,11 @@ def test_fetch_url_with_version(self, get_boto_client):
         # Should be called including the subdirectory path.
         get_boto_client.return_value.generate_presigned_url.assert_called_once_with(
             "get_object",
-            Params={"Bucket": "some-bucket", "Key": "folder/a/b.txt", "VersionId": "some-version"},
+            Params={
+                "Bucket": "some-bucket",
+                "Key": "folder/a/b.txt",
+                "VersionId": "some-version",
+            },
             ExpiresIn=mock.ANY,
         )
 
@@ -113,9 +133,14 @@ def test_list_s3_keys_page(self, mock_s3_bucket):
         filenames = [f"file_{i:04}.pdf" for i in range(105)]
         for filename in filenames:
             store.store_file(
-                namespace="some/path/", filename=filename, contents=f"{filename} content"
+                namespace="some/path/",
+                filename=filename,
+                contents=f"{filename} content",
             )
-        expected = [storage.S3Object("some-bucket", f"path/{filename}") for filename in filenames]
+        expected = [
+            storage.S3Object("some-bucket", f"path/{filename}")
+            for filename in filenames
+        ]
 
         # "file_00" excludes file_0100.pdf and above
         store = storage.S3SubdirectoryFileStore("s3://some-bucket/some")
@@ -146,7 +171,9 @@ def test_list_files(self, get_boto_bucket):
             "a/b/bar.txt",
         ]
         # Should be called including the subdirectory path.
-        get_boto_bucket.return_value.objects.filter.assert_called_once_with(Prefix="folder/a/b")
+        get_boto_bucket.return_value.objects.filter.assert_called_once_with(
+            Prefix="folder/a/b"
+        )
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object")
     def test_fetch_file_fetches_given_path(self, get_boto_object):
@@ -173,23 +200,32 @@ def test_fetch_file_fetches_given_path(self, get_boto_object):
 
 
 class TestS3FileStore:
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_stores_file_that_does_not_exist(self, get_boto_client, get_boto_object_for_key):
+    def test_stores_file_that_does_not_exist(
+        self, get_boto_client, get_boto_object_for_key
+    ):
         get_boto_object_for_key.side_effect = storage.KeyDoesNotExist
         store = storage.S3FileStore("bucket")
 
-        store.store_file(namespace="files", filename="file.pdf", contents="some-content")
+        store.store_file(
+            namespace="files", filename="file.pdf", contents="some-content"
+        )
 
         s3_client = get_boto_client.return_value
         s3_client.upload_fileobj.assert_called_once()
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_overwrites_file_that_does_exist(self, get_boto_client, get_boto_object_for_key):
+    def test_overwrites_file_that_does_exist(
+        self, get_boto_client, get_boto_object_for_key
+    ):
         get_boto_object_for_key.return_value = mock.Mock()
         store = storage.S3FileStore("bucket")
 
         store.store_file(
-            namespace="files", filename="file.pdf", contents="some-content", overwrite=True
+            namespace="files",
+            filename="file.pdf",
+            contents="some-content",
+            overwrite=True,
         )
 
         s3_client = get_boto_client.return_value
@@ -197,16 +233,22 @@ def test_overwrites_file_that_does_exist(self, get_boto_client, get_boto_object_
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_raises_error_for_file_that_does_exist(self, get_boto_client, get_boto_object_for_key):
+    def test_raises_error_for_file_that_does_exist(
+        self, get_boto_client, get_boto_object_for_key
+    ):
         get_boto_object_for_key.return_value = mock.Mock()
         store = storage.S3FileStore("bucket")
 
         with pytest.raises(storage.FileExists):
-            store.store_file(namespace="files", filename="file.pdf", contents="some-content")
+            store.store_file(
+                namespace="files", filename="file.pdf", contents="some-content"
+            )
 
     @mock.patch.object(storage.S3FileStore, "_bucket_is_versioned")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_stores_file_in_versioned_bucket(self, get_boto_client, get_bucket_is_versioned):
+    def test_stores_file_in_versioned_bucket(
+        self, get_boto_client, get_bucket_is_versioned
+    ):
         bucket_name = "bucket"
         namespace = "files"
         filename = "file.pdf"
@@ -227,17 +269,23 @@ def test_stores_file_in_versioned_bucket(self, get_boto_client, get_bucket_is_ve
 
     @mock.patch.object(storage.S3FileStore, "_bucket_is_versioned")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_raises_error_for_unversioned_bucket(self, get_boto_client, get_bucket_is_versioned):
+    def test_raises_error_for_unversioned_bucket(
+        self, get_boto_client, get_bucket_is_versioned
+    ):
         get_boto_client.return_value = mock.Mock()
         get_bucket_is_versioned.return_value = False
         store = storage.S3FileStore("bucket", use_date_in_key_path=False)
 
         with pytest.raises(storage.BucketNotVersioned):
-            store.store_versioned_file(key_path="files/file.pdf", contents="some-content")
+            store.store_versioned_file(
+                key_path="files/file.pdf", contents="some-content"
+            )
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_stores_filepath_that_does_not_exist(self, get_boto_client, get_boto_object_for_key):
+    def test_stores_filepath_that_does_not_exist(
+        self, get_boto_client, get_boto_object_for_key
+    ):
         get_boto_object_for_key.side_effect = storage.KeyDoesNotExist
         store = storage.S3FileStore("bucket")
 
@@ -248,7 +296,9 @@ def test_stores_filepath_that_does_not_exist(self, get_boto_client, get_boto_obj
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage.S3FileStore, "_get_boto_client")
-    def test_overwrites_filepath_that_does_exist(self, get_boto_client, get_boto_object_for_key):
+    def test_overwrites_filepath_that_does_exist(
+        self, get_boto_client, get_boto_object_for_key
+    ):
         get_boto_object_for_key.return_value = mock.Mock()
         store = storage.S3FileStore("bucket")
 
@@ -265,7 +315,10 @@ def test_adds_metadata(self, get_boto_client, get_boto_object_for_key):
         metadata = {"some": "metadata"}
 
         store.store_file(
-            namespace="files", filename="file.pdf", contents="some-content", metadata=metadata
+            namespace="files",
+            filename="file.pdf",
+            contents="some-content",
+            metadata=metadata,
         )
 
         s3_client = get_boto_client.return_value
@@ -336,11 +389,15 @@ def test_s3_file_store_bucket_length(self):
         with pytest.raises(ValueError):
             storage.S3FileStore("ab")
         with pytest.raises(ValueError):
-            storage.S3FileStore("loremlipsumdolorsitametconsecteturadipiscingelitnullamtinciduntu")
+            storage.S3FileStore(
+                "loremlipsumdolorsitametconsecteturadipiscingelitnullamtinciduntu"
+            )
         # Should not raise
         storage.S3FileStore("abc")
         # Should not raise
-        storage.S3FileStore("loremlipsumdolorsitametconsecteturadipiscingelitnullamtincidunt")
+        storage.S3FileStore(
+            "loremlipsumdolorsitametconsecteturadipiscingelitnullamtincidunt"
+        )
 
     def test_make_key_path_raises_error_when_exceeds_max_length(self):
         s3_file_store = storage.S3FileStore("some-bucket")
@@ -353,12 +410,21 @@ def test_make_key_path_raises_error_when_exceeds_max_length(self):
         [
             ("", "file.txt", "2021/09/10/file.txt"),
             ("namespace", "file.txt", "namespace/2021/09/10/file.txt"),
-            ("namespace/sub-namespace", "file.txt", "namespace/sub-namespace/2021/09/10/file.txt"),
+            (
+                "namespace/sub-namespace",
+                "file.txt",
+                "namespace/sub-namespace/2021/09/10/file.txt",
+            ),
         ],
     )
-    def test_make_key_path_with_use_date_in_key_path(self, namespace, filepath, expected):
+    def test_make_key_path_with_use_date_in_key_path(
+        self, namespace, filepath, expected
+    ):
         s3_file_store = storage.S3FileStore("some-bucket", use_date_in_key_path=True)
-        assert s3_file_store.make_key_path(namespace=namespace, filepath=filepath) == expected
+        assert (
+            s3_file_store.make_key_path(namespace=namespace, filepath=filepath)
+            == expected
+        )
 
     @time_machine.travel("2021-09-10", tick=False)
     @pytest.mark.parametrize(
@@ -369,12 +435,21 @@ def test_make_key_path_with_use_date_in_key_path(self, namespace, filepath, expe
             ("namespace", "file.txt", "namespace/file.txt"),
             ("namespace/", "file.txt", "namespace/file.txt"),
             ("namespace/sub-namespace", "file.txt", "namespace/sub-namespace/file.txt"),
-            ("namespace/sub-namespace/", "file.txt", "namespace/sub-namespace/file.txt"),
+            (
+                "namespace/sub-namespace/",
+                "file.txt",
+                "namespace/sub-namespace/file.txt",
+            ),
         ],
     )
-    def test_make_key_path_without_use_date_in_key_path(self, namespace, filepath, expected):
+    def test_make_key_path_without_use_date_in_key_path(
+        self, namespace, filepath, expected
+    ):
         s3_file_store = storage.S3FileStore("some-bucket", use_date_in_key_path=False)
-        assert s3_file_store.make_key_path(namespace=namespace, filepath=filepath) == expected
+        assert (
+            s3_file_store.make_key_path(namespace=namespace, filepath=filepath)
+            == expected
+        )
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
     @mock.patch.object(storage, "open", new_callable=mock.mock_open)
@@ -443,7 +518,9 @@ def test_fetch_file_contents_using_s3_select_and_expect_output_in_csv_format(sel
                 },
                 "CompressionType": "NONE",
             },
-            OutputSerialization={"CSV": {"FieldDelimiter": ",", "RecordDelimiter": "\n"}},
+            OutputSerialization={
+                "CSV": {"FieldDelimiter": ",", "RecordDelimiter": "\n"}
+            },
         )
 
     @mock.patch.object(storage.S3FileStore, "_get_boto_object_for_key")
@@ -538,7 +615,9 @@ def test_fetch_file_contents_using_s3_select_with_parquet_as_input(self):
             OutputSerialization={"JSON": {"RecordDelimiter": "\n"}},
         )
 
-    def test_fetch_file_contents_using_s3_select_with_parquet_fails_with_scan_range(self):
+    def test_fetch_file_contents_using_s3_select_with_parquet_fails_with_scan_range(
+        self
+    ):
         store = storage.S3FileStore("some-bucket")
 
         # Moto doesn't support faking a response from `select_object_content` that's why
@@ -555,7 +634,6 @@ def test_fetch_file_contents_using_s3_select_with_parquet_fails_with_scan_range(
         }
 
         with pytest.raises(ValueError) as error:
-
             list(
                 store.fetch_file_contents_using_s3_select(
                     key_path="some_file.parquet",
@@ -566,10 +644,15 @@ def test_fetch_file_contents_using_s3_select_with_parquet_fails_with_scan_range(
                 )
             )
 
-        assert str(error.value) == "The scan_range parameter is not supported for parquet files"
+        assert (
+            str(error.value)
+            == "The scan_range parameter is not supported for parquet files"
+        )
 
     @pytest.mark.parametrize("expected_error_code", [400, 401, 403, 500])
-    def test_fetch_file_contents_using_s3_select_raises_errors(self, expected_error_code):
+    def test_fetch_file_contents_using_s3_select_raises_errors(
+        self, expected_error_code
+    ):
         store = storage.S3FileStore("some-bucket")
 
         boto_client = mock.Mock()
@@ -581,7 +664,8 @@ def test_fetch_file_contents_using_s3_select_raises_errors(self, expected_error_
         }
 
         with pytest.raises(
-            storage.S3SelectUnexpectedResponse, match="Received invalid response from S3 Select"
+            storage.S3SelectUnexpectedResponse,
+            match="Received invalid response from S3 Select",
         ):
             file_contents = list(
                 store.fetch_file_contents_using_s3_select(
@@ -627,9 +711,13 @@ def test_versioned_store_and_fetch(self):
         contents = self.store.fetch_file_contents(path)
         assert contents == b"last_contents"
 
-    @mock.patch.object(builtins, "open", mock.mock_open(read_data=b"test_store_filepath"))
+    @mock.patch.object(
+        builtins, "open", mock.mock_open(read_data=b"test_store_filepath")
+    )
     def test_store_filepath(self, *mocks):
-        bucket_name, path = self.store.store_filepath(namespace="x", filepath="test.pdf")
+        bucket_name, path = self.store.store_filepath(
+            namespace="x", filepath="test.pdf"
+        )
 
         assert bucket_name == "bucket"
         assert path == "x/test.pdf"
@@ -637,7 +725,9 @@ def test_store_filepath(self, *mocks):
         assert contents == b"test_store_filepath"
 
     @mock.patch.object(
-        builtins, "open", mock.mock_open(read_data=b"test_store_filepath_with_dest_filepath")
+        builtins,
+        "open",
+        mock.mock_open(read_data=b"test_store_filepath_with_dest_filepath"),
     )
     def test_store_filepath_with_dest_filepath(self, *mocks):
         bucket_name, path = self.store.store_filepath(
@@ -658,7 +748,9 @@ def test_fetch_nonexistent(self):
     def test_list_s3_keys_page(self):
         filenames = [f"file_{i:04}.pdf" for i in range(105)]
         for filename in filenames:
-            self.store.store_file(namespace="", filename=filename, contents=f"{filename} content")
+            self.store.store_file(
+                namespace="", filename=filename, contents=f"{filename} content"
+            )
 
         expected = [storage.S3Object("bucket", filename) for filename in filenames]
 
@@ -669,23 +761,37 @@ def test_list_s3_keys_page(self):
         assert not next_token
 
     def test_list_files(self):
-        self.store.store_file(namespace="x", filename="test.pdf", contents=b"test_list_files_1")
-        self.store.store_file(namespace="x", filename="test2.pdf", contents=b"test_list_files_2")
-        self.store.store_file(namespace="y", filename="test3.pdf", contents=b"test_list_files_3")
+        self.store.store_file(
+            namespace="x", filename="test.pdf", contents=b"test_list_files_1"
+        )
+        self.store.store_file(
+            namespace="x", filename="test2.pdf", contents=b"test_list_files_2"
+        )
+        self.store.store_file(
+            namespace="y", filename="test3.pdf", contents=b"test_list_files_3"
+        )
 
         listings = self.store.list_files(namespace="x")
 
         assert list(listings) == ["x/test.pdf", "x/test2.pdf"]
 
     def test_list_files_without_namespace(self):
-        self.store.store_file(namespace="x", filename="test.pdf", contents=b"test_list_files_1")
-        self.store.store_file(namespace="x", filename="test2.pdf", contents=b"test_list_files_2")
-        self.store.store_file(namespace="y", filename="test3.pdf", contents=b"test_list_files_3")
+        self.store.store_file(
+            namespace="x", filename="test.pdf", contents=b"test_list_files_1"
+        )
+        self.store.store_file(
+            namespace="x", filename="test2.pdf", contents=b"test_list_files_2"
+        )
+        self.store.store_file(
+            namespace="y", filename="test3.pdf", contents=b"test_list_files_3"
+        )
 
         listings = self.store.list_files()
 
         assert list(listings) == ["x/test.pdf", "x/test2.pdf", "y/test3.pdf"]
 
     def test_download_file(self):
-        self.store.store_file(namespace="mem", filename="test.pdf", contents=b"test_download_file")
+        self.store.store_file(
+            namespace="mem", filename="test.pdf", contents=b"test_download_file"
+        )
 
         file = self.store.download_file("mem/test.pdf")
 
         assert file.name == "/tmp/bucket/mem/test.pdf"
@@ -713,7 +819,9 @@ def test_store_and_fetch(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=False)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
 
             contents = store.fetch_file_contents(path)
             assert contents == b"hello"
@@ -749,7 +857,9 @@ def test_store_and_fetch_datepath(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=True)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
 
             contents = store.fetch_file_contents(path)
             assert contents == b"hello"
@@ -762,7 +872,9 @@ def test_get_last_modified(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=True)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
 
             last_modified = store.get_last_modified(path)
             assert last_modified is not None
@@ -800,7 +912,9 @@ def test_exists(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=False)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
 
             assert store.exists(path) is True
 
@@ -808,7 +922,9 @@ def test_exists_datepath(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=True)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
 
             assert store.exists(path) is True
 
@@ -829,9 +945,15 @@ def test_list_files(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=False)
 
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents="hello")
-            __, path = store.store_file(namespace="x", filename="test2.pdf", contents="goodbye")
-            __, path = store.store_file(namespace="y", filename="test3.pdf", contents="goodbye")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents="hello"
+            )
+            __, path = store.store_file(
+                namespace="x", filename="test2.pdf", contents="goodbye"
+            )
+            __, path = store.store_file(
+                namespace="y", filename="test3.pdf", contents="goodbye"
+            )
 
             listings = store.list_files(namespace="x")
             assert sorted(list(listings)) == ["x/test.pdf", "x/test2.pdf"]
@@ -839,7 +961,9 @@ def test_list_files(self):
     def test_download_file(self):
         with tempfile.TemporaryDirectory() as tdir:
             store = storage.LocalFileStore("bucket", tdir, use_date_in_key_path=False)
-            __, path = store.store_file(namespace="x", filename="test.pdf", contents=b"hello")
+            __, path = store.store_file(
+                namespace="x", filename="test.pdf", contents=b"hello"
+            )
 
             file = store.download_file("x/test.pdf")
 
@@ -859,10 +983,11 @@ def test_fetch_url_with_version(self):
 
     @mock.patch.object(storage.LocalFileStore, "_filepath_for_key_path")
     def test_fetch_csv_file_contents_using_s3_select(self, mock__filepath_for_key_path):
-
         store = storage.LocalFileStore("my_bucket")
         mock_csv_data = "Name,Age\nAlice,25\nBob,30\nCharlie,35\n"
-        with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".csv") as tmp_csv_file:
+        with tempfile.NamedTemporaryFile(
+            delete=False, mode="w", suffix=".csv"
+        ) as tmp_csv_file:
             tmp_csv_file.write(mock_csv_data)
             tmp_csv_file_path = tmp_csv_file.name
 
@@ -893,10 +1018,11 @@ def test_fetch_csv_file_contents_using_s3_select(self, mock__filepath_for_key_pa
     def test_fetch_csv_file_contents_using_s3_select_and_where_statement(
         self, mock__filepath_for_key_path
     ):
-
         store = storage.LocalFileStore("my_bucket")
         mock_csv_data = "Name,Age\nAlice,25\nBob,30\nCharlie,35\n"
-        with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".csv") as tmp_csv_file:
+        with tempfile.NamedTemporaryFile(
+            delete=False, mode="w", suffix=".csv"
+        ) as tmp_csv_file:
            tmp_csv_file.write(mock_csv_data)
            tmp_csv_file_path = tmp_csv_file.name
 
@@ -920,8 +1046,9 @@ def test_fetch_csv_file_contents_using_s3_select_and_where_statement(
         assert results == expected_results
 
     @mock.patch.object(storage.LocalFileStore, "_filepath_for_key_path")
-    def test_fetch_parquet_file_contents_using_s3_select(self, mock__filepath_for_key_path):
-
+    def test_fetch_parquet_file_contents_using_s3_select(
+        self, mock__filepath_for_key_path
+    ):
         store = storage.LocalFileStore("my_bucket")
         mock_data = {"Name": ["Alice", "Bob", "Charlie"], "Age": [25, 30, 35]}
         df = pd.DataFrame(mock_data)
@@ -956,7 +1083,6 @@ def test_fetch_parquet_file_contents_using_s3_select(self, mock__filepath_for_ke
         assert results == expected_results
 
     def test_fetch_nonexistent_file_with_s3_select(self):
-
         input_serializer = s3_select.CSVInputSerializer(s3_select.FileHeaderInfo.USE)
         output_serializer = s3_select.JSONOutputSerializer()
         store = storage.LocalFileStore("my_bucket")
@@ -972,7 +1098,6 @@ def test_fetch_nonexistent_file_with_s3_select(self):
             )
 
     def test_fetch_file_with_s3_select_scan_range_raises_error(self):
-
         input_serializer = s3_select.CSVInputSerializer(s3_select.FileHeaderInfo.USE)
         output_serializer = s3_select.JSONOutputSerializer()
         store = storage.LocalFileStore("my_bucket")
@@ -992,7 +1117,6 @@ def test_fetch_file_with_s3_select_scan_range_raises_error(self):
     def test_json_output_unsupported_record_separator_raises_exception(
         self, mock__filepath_for_key_path
     ):
-
         store = storage.LocalFileStore("my_bucket")
         mock_data = {"Name": ["Alice", "Bob", "Charlie"], "Age": [25, 30, 35]}
         df = pd.DataFrame(mock_data)
@@ -1022,7 +1146,9 @@ def test_json_output_unsupported_record_separator_raises_exception(
 
     def test_output_csv_with_serializer_quoting_always(self):
         store = storage.LocalFileStore("my_bucket")
-        serializer = s3_select.CSVOutputSerializer(QuoteFields=s3_select.QuoteFields.ALWAYS)
+        serializer = s3_select.CSVOutputSerializer(
+            QuoteFields=s3_select.QuoteFields.ALWAYS
+        )
         result = store.output_csv_with_serializer(
             df=self.sample_dataframe, output_serializer=serializer
         )
@@ -1030,12 +1156,13 @@ def test_output_csv_with_serializer_quoting_always(self):
         assert result == expected
 
     def test_output_csv_with_serializer_quoting_as_needed(self):
-
         sample_dataframe = pd.DataFrame(
             {"Name": ["Ali,ce", "Bob", "Charlie"], "Age": [25, 30, 35]}
         )
         store = storage.LocalFileStore("my_bucket")
-        serializer = s3_select.CSVOutputSerializer(QuoteFields=s3_select.QuoteFields.ASNEEDED)
+        serializer = s3_select.CSVOutputSerializer(
+            QuoteFields=s3_select.QuoteFields.ASNEEDED
+        )
         result = store.output_csv_with_serializer(
             df=sample_dataframe, output_serializer=serializer
         )
@@ -1072,7 +1199,9 @@ def test_output_csv_with_serializer_custom_record_delimiter(self):
 
     def test_read_csv_with_serializer(self):
         store = storage.LocalFileStore("my_bucket")
-        with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".csv") as tmp_csv_file:
+        with tempfile.NamedTemporaryFile(
+            delete=False, mode="w", suffix=".csv"
+        ) as tmp_csv_file:
             tmp_csv_file.write(self.csv_data)
             tmp_csv_file_path = tmp_csv_file.name
         input_serializer = s3_select.CSVInputSerializer(s3_select.FileHeaderInfo.USE)
@@ -1083,7 +1212,6 @@ def test_read_csv_with_serializer(self):
         assert isinstance(result, pd.DataFrame)
 
     def test_query_dataframe_with_sql(self):
-
         data = {
             "string_column": ["A", "B", "C"],
             "array_column": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
@@ -1101,7 +1229,6 @@ def test_query_dataframe_with_sql(self):
         assert result_df["array_column"][0] == [1, 2, 3]
 
     def test_query_dataframe_with_sql_with_capitalised_object_in_query(self):
-
         dummy_df = pd.DataFrame(self.sample_dataframe)
         store = storage.LocalFileStore("my_bucket")
from django.test import override_settings @@ -26,7 +26,9 @@ def test_seconds_in_future(self): localtime.datetime_.datetime(2020, 1, 1, 12, 0, 1, tzinfo=localtime.UTC) ) assert localtime.seconds_in_the_future(1.5) == localtime.as_localtime( - localtime.datetime_.datetime(2020, 1, 1, 12, 0, 1, 500000, tzinfo=localtime.UTC) + localtime.datetime_.datetime( + 2020, 1, 1, 12, 0, 1, 500000, tzinfo=localtime.UTC + ) ) @@ -34,22 +36,30 @@ class TestSecondsInThePast: def test_seconds_in_past(self): with time_machine.travel("2020-01-01 12:00:00.000", tick=False): assert localtime.seconds_in_the_past(1) == localtime.as_localtime( - localtime.datetime_.datetime(2020, 1, 1, 11, 59, 59, tzinfo=localtime.UTC) + localtime.datetime_.datetime( + 2020, 1, 1, 11, 59, 59, tzinfo=localtime.UTC + ) ) assert localtime.seconds_in_the_past(1.5) == localtime.as_localtime( - localtime.datetime_.datetime(2020, 1, 1, 11, 59, 58, 500000, tzinfo=localtime.UTC) + localtime.datetime_.datetime( + 2020, 1, 1, 11, 59, 58, 500000, tzinfo=localtime.UTC + ) ) class TestDate: def test_date_calculation_near_midnight_during_bst(self): - near_midnight_in_utc = datetime.datetime(2016, 6, 1, 23, 50, 0, tzinfo=localtime.UTC) + near_midnight_in_utc = datetime.datetime( + 2016, 6, 1, 23, 50, 0, tzinfo=localtime.UTC + ) assert localtime.date(near_midnight_in_utc) == ( near_midnight_in_utc.date() + datetime.timedelta(days=1) ) def test_date_calculation_near_midnight_outside_of_bst(self): - near_midnight_in_utc = datetime.datetime(2016, 11, 1, 23, 50, 0, tzinfo=localtime.UTC) + near_midnight_in_utc = datetime.datetime( + 2016, 11, 1, 23, 50, 0, tzinfo=localtime.UTC + ) assert localtime.date(near_midnight_in_utc) == near_midnight_in_utc.date() @pytest.mark.parametrize("tz", (zoneinfo.ZoneInfo("Etc/GMT-10"), localtime.UTC)) @@ -62,7 +72,9 @@ def test_datetime_not_supplied(self): Check that we do not fallback to today if a datetime is not passed to the function - we have localtime.today for that. 
""" - with pytest.raises(TypeError, match="You must supply a datetime to localtime.date"): + with pytest.raises( + TypeError, match="You must supply a datetime to localtime.date" + ): localtime.date(None) @@ -115,7 +127,9 @@ def test_convert_date_to_midnight_and_back(self): def test_midnight_in_different_timezone(self): aus_time = zoneinfo.ZoneInfo("Etc/GMT-10") - with time_machine.travel(datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False): + with time_machine.travel( + datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False + ): result = localtime.midnight(tz=aus_time) assert result == datetime.datetime(2020, 2, 2, 0, 0, tzinfo=aus_time) @@ -128,19 +142,39 @@ def test_doesnt_change_date_of_already_midnight_datetime(self): midnight = datetime.datetime(2020, 6, 1, 23, tzinfo=datetime.timezone.utc) # We'll assert the same thing three ways for clarity: - assert localtime.midnight(midnight).date() == localtime.as_localtime(midnight).date() + assert ( + localtime.midnight(midnight).date() + == localtime.as_localtime(midnight).date() + ) assert localtime.midnight(midnight).date() == datetime.date(2020, 6, 2) assert localtime.midnight(midnight) == midnight - @override_settings(TIME_ZONE="Australia/Sydney") # set the django default/current timezone + @override_settings( + TIME_ZONE="Australia/Sydney" + ) # set the django default/current timezone @pytest.mark.parametrize( "naive_datetime,expected_midnight", [ - (datetime.datetime(2021, 6, 17, 18, 0, 0), datetime.datetime(2021, 6, 17, 0, 0, 0)), - (datetime.datetime(2021, 6, 17, 23, 30, 0), datetime.datetime(2021, 6, 17, 0, 0, 0)), - (datetime.datetime(2021, 6, 18, 0, 0, 0), datetime.datetime(2021, 6, 18, 0, 0, 0)), - (datetime.datetime(2021, 6, 18, 0, 30, 0), datetime.datetime(2021, 6, 18, 0, 0, 0)), - (datetime.datetime(2021, 6, 18, 6, 0, 0), datetime.datetime(2021, 6, 18, 0, 0, 0)), + ( + datetime.datetime(2021, 6, 17, 18, 0, 0), + datetime.datetime(2021, 6, 17, 0, 0, 0), + ), + ( + datetime.datetime(2021, 6, 17, 23, 30, 0), + datetime.datetime(2021, 6, 17, 0, 0, 0), + ), + ( + datetime.datetime(2021, 6, 18, 0, 0, 0), + datetime.datetime(2021, 6, 18, 0, 0, 0), + ), + ( + datetime.datetime(2021, 6, 18, 0, 30, 0), + datetime.datetime(2021, 6, 18, 0, 0, 0), + ), + ( + datetime.datetime(2021, 6, 18, 6, 0, 0), + datetime.datetime(2021, 6, 18, 0, 0, 0), + ), ], ) def test_localtime_midnight_calculation_for_naive_datetime_and_no_timezone( @@ -163,7 +197,9 @@ def test_localtime_midnight_calculation_for_naive_datetime_and_no_timezone( assert actual_midnight == expected_midnight assert str(actual_midnight.tzinfo) == "Australia/Sydney" - @override_settings(TIME_ZONE="Australia/Sydney") # set the django default/current timezone + @override_settings( + TIME_ZONE="Australia/Sydney" + ) # set the django default/current timezone @pytest.mark.parametrize( "naive_datetime,specified_timezone,expected_midnight", [ @@ -230,7 +266,9 @@ def test_localtime_midnight_calculation_for_naive_datetime_and_specified_timezon """ specified_timezone_obj = zoneinfo.ZoneInfo(specified_timezone) # attach the specified timezone to the expected midnight - expected_midnight = timezone.make_aware(expected_midnight, timezone=specified_timezone_obj) + expected_midnight = timezone.make_aware( + expected_midnight, timezone=specified_timezone_obj + ) actual_midnight = localtime.midnight(naive_datetime, tz=specified_timezone_obj) @@ -256,7 +294,9 @@ def test_midday_calculation_without_date_uses_today(self): def test_midday_in_different_timezone(self): aus_time = 
zoneinfo.ZoneInfo("Etc/GMT-10") - with time_machine.travel(datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False): + with time_machine.travel( + datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False + ): result = localtime.midday(tz=aus_time) assert result == datetime.datetime(2020, 2, 2, 12, 0, tzinfo=aus_time) @@ -521,7 +561,10 @@ class TestStartOfMonth: localtime.datetime(2016, 12, 5, 11, 34, 59), localtime.datetime(2016, 12, 1, 0, 0, 0), ), - (localtime.datetime(2017, 3, 31, 11, 29, 59), localtime.datetime(2017, 3, 1, 0, 0, 0)), + ( + localtime.datetime(2017, 3, 31, 11, 29, 59), + localtime.datetime(2017, 3, 1, 0, 0, 0), + ), ], ) def test_start_of_month(self, dt, result): @@ -532,8 +575,14 @@ class TestEndOfMonth: @pytest.mark.parametrize( ("dt", "result"), [ - (localtime.datetime(2016, 12, 5, 11, 34, 59), localtime.datetime(2017, 1, 1, 0, 0, 0)), - (localtime.datetime(2017, 3, 31, 11, 29, 59), localtime.datetime(2017, 4, 1, 0, 0, 0)), + ( + localtime.datetime(2016, 12, 5, 11, 34, 59), + localtime.datetime(2017, 1, 1, 0, 0, 0), + ), + ( + localtime.datetime(2017, 3, 31, 11, 29, 59), + localtime.datetime(2017, 4, 1, 0, 0, 0), + ), ], ) def test_end_of_month(self, dt, result): @@ -589,7 +638,9 @@ def test_dst_end(self): def test_default_in_different_timezone(self): aus_time = zoneinfo.ZoneInfo("Etc/GMT-10") - with time_machine.travel(datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False): + with time_machine.travel( + datetime.datetime(2020, 2, 2, 1, tzinfo=aus_time), tick=False + ): result = localtime.next_midnight(tz=aus_time) assert result == datetime.datetime(2020, 2, 3, 0, 0, tzinfo=aus_time) @@ -616,10 +667,12 @@ def test_dst_end_datetime(self, dt, expected): class TestDaysInThePast: def test_is_sane(self): - assert localtime.days_in_the_past(2) == datetime.date.today() - datetime.timedelta(days=2) - assert localtime.days_in_the_past(-20) == datetime.date.today() + datetime.timedelta( - days=20 - ) + assert localtime.days_in_the_past( + 2 + ) == datetime.date.today() - datetime.timedelta(days=2) + assert localtime.days_in_the_past( + -20 + ) == datetime.date.today() + datetime.timedelta(days=20) assert localtime.days_in_the_past(0) == datetime.date.today() assert localtime.days_in_the_past(1) == localtime.yesterday() assert localtime.days_in_the_past(-1) == localtime.tomorrow() @@ -627,12 +680,12 @@ def test_is_sane(self): class TestDaysInTheFuture: def test_is_sane(self): - assert localtime.days_in_the_future(2) == datetime.date.today() + datetime.timedelta( - days=2 - ) - assert localtime.days_in_the_future(-20) == datetime.date.today() - datetime.timedelta( - days=20 - ) + assert localtime.days_in_the_future( + 2 + ) == datetime.date.today() + datetime.timedelta(days=2) + assert localtime.days_in_the_future( + -20 + ) == datetime.date.today() - datetime.timedelta(days=20) assert localtime.days_in_the_future(0) == datetime.date.today() assert localtime.days_in_the_future(1) == localtime.tomorrow() assert localtime.days_in_the_future(-1) == localtime.yesterday() @@ -644,7 +697,12 @@ class TestLatestDateForDay: ( ("2017-01-01", "2018-12-31", 9, "2018-12-09"), # Result in last month. ("2017-01-01", "2018-12-08", 9, "2018-11-09"), # Result in previous month. - ("2017-01-01", "2017-03-30", 31, "2017-01-31"), # Result affected by short month. + ( + "2017-01-01", + "2017-03-30", + 31, + "2017-01-31", + ), # Result affected by short month. ("2017-01-12", "2017-01-30", 12, "2017-01-12"), # Result same as from date. 
("2017-01-12", "2017-01-30", 30, "2017-01-30"), # Result same as to date. ("2017-01-12", "2017-02-10", 11, None), # Result not in range. @@ -653,7 +711,9 @@ class TestLatestDateForDay: ("2017-01-01", "2018-12-31", 32, ValueError), # Day too high. ), ) - def test_latest_date_for_day(self, start_date, end_date, day_of_month, expected_result): + def test_latest_date_for_day( + self, start_date, end_date, day_of_month, expected_result + ): kwargs = dict( start_date=factories.date(start_date), end_date=factories.date(end_date), @@ -736,7 +796,9 @@ def test_returns_correct_results_for_dates( now = factories.local.dt(now_str) supplied_date = factories.date(supplied_date_str) with time_machine.travel(now, tick=False): - assert localtime.is_within_the_last_week(supplied_date) == is_within_last_year + assert ( + localtime.is_within_the_last_week(supplied_date) == is_within_last_year + ) class TestIsDST: @@ -747,21 +809,53 @@ class TestIsDST: (datetime.datetime(2019, 1, 1), zoneinfo.ZoneInfo("Europe/London"), False), (datetime.datetime(2019, 6, 1), zoneinfo.ZoneInfo("Europe/London"), True), # Test London boundaries - (datetime.datetime(2017, 3, 26, 0, 0), zoneinfo.ZoneInfo("Europe/London"), False), - (datetime.datetime(2017, 3, 26, 2, 0), zoneinfo.ZoneInfo("Europe/London"), True), - (datetime.datetime(2017, 10, 29, 0, 0), zoneinfo.ZoneInfo("Europe/London"), True), - (datetime.datetime(2017, 10, 29, 2, 0), zoneinfo.ZoneInfo("Europe/London"), False), + ( + datetime.datetime(2017, 3, 26, 0, 0), + zoneinfo.ZoneInfo("Europe/London"), + False, + ), + ( + datetime.datetime(2017, 3, 26, 2, 0), + zoneinfo.ZoneInfo("Europe/London"), + True, + ), + ( + datetime.datetime(2017, 10, 29, 0, 0), + zoneinfo.ZoneInfo("Europe/London"), + True, + ), + ( + datetime.datetime(2017, 10, 29, 2, 0), + zoneinfo.ZoneInfo("Europe/London"), + False, + ), # UTC should never be DST (datetime.datetime(2019, 1, 1), zoneinfo.ZoneInfo("UTC"), False), (datetime.datetime(2019, 6, 1), zoneinfo.ZoneInfo("UTC"), False), (datetime.datetime(2019, 1, 1), datetime.timezone.utc, False), (datetime.datetime(2019, 6, 1), datetime.timezone.utc, False), # Test Eastern Australia timezone - (datetime.datetime(2019, 1, 1), zoneinfo.ZoneInfo("Australia/Sydney"), True), - (datetime.datetime(2019, 6, 1), zoneinfo.ZoneInfo("Australia/Sydney"), False), + ( + datetime.datetime(2019, 1, 1), + zoneinfo.ZoneInfo("Australia/Sydney"), + True, + ), + ( + datetime.datetime(2019, 6, 1), + zoneinfo.ZoneInfo("Australia/Sydney"), + False, + ), # Test Western Australia timezone (they don't have DST) - (datetime.datetime(2019, 1, 1), zoneinfo.ZoneInfo("Australia/Perth"), False), - (datetime.datetime(2019, 6, 1), zoneinfo.ZoneInfo("Australia/Perth"), False), + ( + datetime.datetime(2019, 1, 1), + zoneinfo.ZoneInfo("Australia/Perth"), + False, + ), + ( + datetime.datetime(2019, 6, 1), + zoneinfo.ZoneInfo("Australia/Perth"), + False, + ), ), ) def test_returns_correct_values(self, naive_datetime, tz, expected): @@ -826,7 +920,9 @@ class TestCombine: factories.date("1 Jun 2020"), factories.time("01:00"), "Europe/London", - datetime.datetime(2020, 6, 1, 1, 0).astimezone(zoneinfo.ZoneInfo("Europe/London")), + datetime.datetime(2020, 6, 1, 1, 0).astimezone( + zoneinfo.ZoneInfo("Europe/London") + ), ), ( factories.date("1 Jul 2021"), @@ -855,7 +951,9 @@ class TestNextDateWithDayOfMonth: ) def test_next_date_with_day_of_month(self, current_date, day_of_month, expected): assert ( - localtime.next_date_with_day_of_month(date=current_date, day_of_month=day_of_month) + 
localtime.next_date_with_day_of_month( + date=current_date, day_of_month=day_of_month + ) == expected ) @@ -879,7 +977,11 @@ def test_invalid_input(self): [(localtime.today(), localtime.today())], ), ( - [localtime.today(), localtime.yesterday(), localtime.days_in_the_future(2)], + [ + localtime.today(), + localtime.yesterday(), + localtime.days_in_the_future(2), + ], [ (localtime.yesterday(), localtime.today()), (localtime.days_in_the_future(2), localtime.days_in_the_future(2)), @@ -920,7 +1022,9 @@ def test_timestamp_british_summer_time_before_clocks_move_forward(self): """ # Before clocks move forward # 29th of March 2020 0:30am UTC = 0:30am Europe/London - timestamp = datetime.datetime(2020, 3, 29, 0, 30, tzinfo=localtime.UTC).timestamp() + timestamp = datetime.datetime( + 2020, 3, 29, 0, 30, tzinfo=localtime.UTC + ).timestamp() dt = localtime.datetime_from_epoch_timestamp(timestamp) @@ -941,7 +1045,9 @@ def test_timestamp_british_summer_time_after_clocks_move_forward(self): """ # After clocks move forward # 29th of March 2020 1:30am UTC = 2:30 am Europe/London - timestamp = datetime.datetime(2020, 3, 29, 1, 30, tzinfo=localtime.UTC).timestamp() + timestamp = datetime.datetime( + 2020, 3, 29, 1, 30, tzinfo=localtime.UTC + ).timestamp() dt = localtime.datetime_from_epoch_timestamp(timestamp) @@ -962,7 +1068,9 @@ def test_timestamp_british_summer_time_before_clocks_move_backward(self): """ # Before clocks move backwards # 25th of October 2020 0:30am UTC = 1:30am Europe/London - timestamp = datetime.datetime(2020, 10, 25, 0, 30, tzinfo=localtime.UTC).timestamp() + timestamp = datetime.datetime( + 2020, 10, 25, 0, 30, tzinfo=localtime.UTC + ).timestamp() dt = localtime.datetime_from_epoch_timestamp(timestamp) @@ -983,7 +1091,9 @@ def test_timestamp_british_summer_time_after_clocks_move_backward(self): """ # After clocks move backwards # 25th of October 2020 1:30am UTC = 1:30am Europe/London - timestamp = datetime.datetime(2020, 10, 25, 1, 30, tzinfo=localtime.UTC).timestamp() + timestamp = datetime.datetime( + 2020, 10, 25, 1, 30, tzinfo=localtime.UTC + ).timestamp() dt = localtime.datetime_from_epoch_timestamp(timestamp) @@ -1047,8 +1157,12 @@ class TestPeriodExceedsOneYear: ), ], ) - def test_period_exceeds_one_year(self, period_start_at, first_dt_exceeding_one_year): - assert localtime.period_exceeds_one_year(period_start_at, first_dt_exceeding_one_year) + def test_period_exceeds_one_year( + self, period_start_at, first_dt_exceeding_one_year + ): + assert localtime.period_exceeds_one_year( + period_start_at, first_dt_exceeding_one_year + ) assert not localtime.period_exceeds_one_year( period_start_at, first_dt_exceeding_one_year - relativedelta.relativedelta(microseconds=1), diff --git a/tests/test_numbers.py b/tests/test_numbers.py index f46bef1..149be5b 100644 --- a/tests/test_numbers.py +++ b/tests/test_numbers.py @@ -33,7 +33,6 @@ def test_quantise(number_to_round, base, rounding_strategy, expected_result): def test_truncate_decimal_places(): - assert numbers.truncate_decimal_places(D("123.45"), 1) == 123.4 assert numbers.truncate_decimal_places(D("123.456"), 1) == 123.4 assert numbers.truncate_decimal_places(D("123.4"), 2) == 123.40 diff --git a/tests/test_ranges.py b/tests/test_ranges.py index c55bca3..2aaafd7 100644 --- a/tests/test_ranges.py +++ b/tests/test_ranges.py @@ -249,7 +249,9 @@ def test_union_and_intersection_are_commutative( @given(valid_integer_range(), valid_integer_range()) -def test_union_and_intersection_are_idempotent(a: ranges.Range[Any], b: ranges.Range[Any]) 
-> None: +def test_union_and_intersection_are_idempotent( + a: ranges.Range[Any], b: ranges.Range[Any] +) -> None: union = a | b assume(union is not None) assert union is not None @@ -276,7 +278,9 @@ def test_range_difference_and_intersection_form_partition( # a contains b assert b_difference is None assert a_difference.is_disjoint(ranges.RangeSet([intersection])) - assert a_difference | ranges.RangeSet([intersection]) == ranges.RangeSet([a]) + assert a_difference | ranges.RangeSet( + [intersection] + ) == ranges.RangeSet([a]) else: assert a_difference.is_disjoint(intersection) assert a_difference | intersection == a @@ -286,7 +290,9 @@ def test_range_difference_and_intersection_form_partition( # b contains a assert a_difference is None assert b_difference.is_disjoint(ranges.RangeSet([intersection])) - assert b_difference | ranges.RangeSet([intersection]) == ranges.RangeSet([b]) + assert b_difference | ranges.RangeSet( + [intersection] + ) == ranges.RangeSet([b]) else: assert b_difference.is_disjoint(intersection) assert b_difference | intersection == b @@ -387,7 +393,9 @@ def test_finite_range(): (ranges.RangeSet([ranges.Range(1, 3), ranges.Range(0, 2)]), "{[0,3)}"), ], ) -def test_rangeset_construction(rangeset: ranges.RangeSet[Any], expected_string: str) -> None: +def test_rangeset_construction( + rangeset: ranges.RangeSet[Any], expected_string: str +) -> None: assert str(rangeset) == expected_string @@ -412,7 +420,11 @@ def test_rangeset_addition(a: ranges.Range[Any], b: ranges.Range[Any]) -> None: # Partial match (ranges.RangeSet([ranges.Range(0, 5)]), ranges.Range(1, 6), False), # Partial match - (ranges.RangeSet([ranges.Range(0, 2), ranges.Range(3, 7)]), ranges.Range(1, 6), False), + ( + ranges.RangeSet([ranges.Range(0, 2), ranges.Range(3, 7)]), + ranges.Range(1, 6), + False, + ), ], ) def test_rangeset_contains_range(rangeset, item, expected_result): @@ -574,7 +586,9 @@ class TestAnyOverlapping: ranges.Range(1, 3), ], [ - ranges.Range(0, 2, boundaries=ranges.RangeBoundaries.INCLUSIVE_INCLUSIVE), + ranges.Range( + 0, 2, boundaries=ranges.RangeBoundaries.INCLUSIVE_INCLUSIVE + ), ranges.Range(2, 4), ], ], @@ -591,7 +605,9 @@ def test_returns_true_if_and_ranges_overlap(self, ranges_): ], [ ranges.Range(0, 2), - ranges.Range(2, 4, boundaries=ranges.RangeBoundaries.EXCLUSIVE_INCLUSIVE), + ranges.Range( + 2, 4, boundaries=ranges.RangeBoundaries.EXCLUSIVE_INCLUSIVE + ), ], ], ) diff --git a/tests/test_settlement_periods.py b/tests/test_settlement_periods.py index 0f4851f..b937265 100644 --- a/tests/test_settlement_periods.py +++ b/tests/test_settlement_periods.py @@ -84,7 +84,10 @@ def test_convert_sp_and_date_to_utc_for_wholesale(sp, date, expected): Test the convert_sp_and_date_to_utc function within a wholesale context for days where british time is the same as GMT, where british time is BST, and change days """ - assert settlement_periods.convert_sp_and_date_to_utc(sp, date, is_wholesale=True) == expected + assert ( + settlement_periods.convert_sp_and_date_to_utc(sp, date, is_wholesale=True) + == expected + ) @pytest.mark.parametrize( diff --git a/tests/test_urls.py b/tests/test_urls.py index 83ae235..914c708 100644 --- a/tests/test_urls.py +++ b/tests/test_urls.py @@ -86,7 +86,9 @@ def test_handles_upload_dir_being_subdir_of_destination_dir(self): relative to that. 
""" dest_url = "ftp://some_server/gas/destination?upload=nested/sub/dir" - (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url(dest_url) + (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url( + dest_url + ) assert fs_url == "ftp://some_server/gas/destination" assert dest_path == "." @@ -100,7 +102,9 @@ def test_handles_destination_dir_being_subdir_of_upload_dir(self): relative to that. """ dest_url = "ftp://some_server/gas/destination?upload=.." - (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url(dest_url) + (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url( + dest_url + ) assert fs_url == "ftp://some_server/gas" assert dest_path == "destination" @@ -115,7 +119,9 @@ def test_handles_destination_dir_and_upload_dir_being_cousins(self): ancestor between the two dirs and both paths are given relative to that. """ dest_url = "ftp://some_server/gas/destination/dir?upload=../../upload/sub/dir" - (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url(dest_url) + (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url( + dest_url + ) assert fs_url == "ftp://some_server/gas" assert dest_path == "destination/dir" @@ -128,7 +134,9 @@ def test_handles_absolute_upload_dir(self): In that case the upload dir is resolved relative to the root of the FS URL. """ dest_url = "ftp://some_server/gas/destination?upload=/gas/upload" - (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url(dest_url) + (fs_url, dest_path, upload_path) = urls.parse_file_destination_from_url( + dest_url + ) assert fs_url == "ftp://some_server/gas" assert dest_path == "destination" diff --git a/xocto/localtime.py b/xocto/localtime.py index 9096e48..da24cde 100644 --- a/xocto/localtime.py +++ b/xocto/localtime.py @@ -3,9 +3,9 @@ import calendar import datetime as datetime_ import decimal +import zoneinfo from typing import Generator, Sequence, Tuple -import zoneinfo from dateutil import tz from dateutil.relativedelta import relativedelta from django.utils import timezone @@ -28,7 +28,9 @@ MIDNIGHT_TIME = datetime_.time(0, 0) -def as_localtime(dt: datetime_.datetime, tz: datetime_.tzinfo | None = None) -> datetime_.datetime: +def as_localtime( + dt: datetime_.datetime, tz: datetime_.tzinfo | None = None +) -> datetime_.datetime: """ Convert a tz aware datetime to localtime. @@ -173,7 +175,9 @@ def day_after(d: datetime_.date) -> datetime_.date: # Returning datetimes -def seconds_in_the_future(n: int, dt: datetime_.datetime | None = None) -> datetime_.datetime: +def seconds_in_the_future( + n: int, dt: datetime_.datetime | None = None +) -> datetime_.datetime: """ Return a datetime of the number of specifed seconds in the future. """ @@ -323,7 +327,9 @@ def date_boundaries( return midnight(_date, tz), next_midnight(_date, tz) -def month_boundaries(month: int, year: int) -> Tuple[datetime_.datetime, datetime_.datetime]: +def month_boundaries( + month: int, year: int +) -> Tuple[datetime_.datetime, datetime_.datetime]: """ Return the boundary datetimes of a given month. @@ -392,7 +398,9 @@ def within_date_range( def quantise( - dt: datetime_.datetime, timedelta: datetime_.timedelta, rounding: str = decimal.ROUND_HALF_EVEN + dt: datetime_.datetime, + timedelta: datetime_.timedelta, + rounding: str = decimal.ROUND_HALF_EVEN, ) -> datetime_.datetime: """ 'Round' a datetime to the nearest interval given by the `timedelta` argument. 
@@ -411,7 +419,9 @@ def quantise(
     quantised_dt_timestamp = numbers.quantise(
         dt_as_timestamp, timedelta_seconds, rounding=rounding
     )
-    quantised_dt = datetime_.datetime.fromtimestamp(quantised_dt_timestamp, tz=dt.tzinfo)
+    quantised_dt = datetime_.datetime.fromtimestamp(
+        quantised_dt_timestamp, tz=dt.tzinfo
+    )
     return as_localtime(quantised_dt)
 
 
@@ -553,7 +563,9 @@ def latest_date_for_day(
     return None
 
 
-def next_date_with_day_of_month(date: datetime_.date, day_of_month: int) -> datetime_.date:
+def next_date_with_day_of_month(
+    date: datetime_.date, day_of_month: int
+) -> datetime_.date:
     """
     Given a starting `date`, return the next date with the specified `day_of_month`.
 
@@ -619,7 +631,9 @@ def is_dst(local_time: datetime_.datetime) -> bool:
     return bool(local_time.dst())
 
 
-def is_localtime_midnight(dt: datetime_.datetime, tz: datetime_.tzinfo | None = None) -> bool:
+def is_localtime_midnight(
+    dt: datetime_.datetime, tz: datetime_.tzinfo | None = None
+) -> bool:
     """
     Return whether the supplied datetime is at midnight (in the site's local time zone).
 
@@ -635,7 +649,9 @@ def is_aligned_to_midnight(
     """
     Return whether this range is aligned to localtime midnight.
     """
-    return all([is_localtime_midnight(range.start, tz), is_localtime_midnight(range.end, tz)])
+    return all(
+        [is_localtime_midnight(range.start, tz), is_localtime_midnight(range.end, tz)]
+    )
 
 
 def consolidate_into_intervals(
@@ -675,12 +691,17 @@ def consolidate_into_intervals(
             num_consecutive += 1
         else:
             intervals.append(
-                (interval_start, interval_start + datetime_.timedelta(days=num_consecutive))
+                (
+                    interval_start,
+                    interval_start + datetime_.timedelta(days=num_consecutive),
+                )
             )
             interval_start = date
             num_consecutive = 0
 
-    intervals.append((interval_start, interval_start + datetime_.timedelta(days=num_consecutive)))
+    intervals.append(
+        (interval_start, interval_start + datetime_.timedelta(days=num_consecutive))
+    )
     return intervals
 
 
@@ -704,7 +725,9 @@ def translate_english_month_to_spanish(month: int) -> str:
     return month_name_lookup[month_name]
 
 
-def period_exceeds_one_year(start_at: datetime_.datetime, end_at: datetime_.datetime) -> bool:
+def period_exceeds_one_year(
+    start_at: datetime_.datetime, end_at: datetime_.datetime
+) -> bool:
     """
     Returns true if the passed period exceeds one year.
 
diff --git a/xocto/numbers.py b/xocto/numbers.py
index 139109a..a358798 100644
--- a/xocto/numbers.py
+++ b/xocto/numbers.py
@@ -7,7 +7,9 @@
 from . import types
 
 
-def quantise(number: int | float | str, base: int, rounding: str = decimal.ROUND_HALF_EVEN) -> int:
+def quantise(
+    number: int | float | str, base: int, rounding: str = decimal.ROUND_HALF_EVEN
+) -> int:
     """
     Round a number to an arbitrary integer base. For example:
     >>> quantise(256, 5)
@@ -79,7 +81,9 @@ def round_decimal_places(
     return value.quantize(decimal.Decimal(quantize_string), rounding=rounding)
 
 
-def round_to_integer(value: decimal.Decimal, rounding: str = decimal.ROUND_HALF_UP) -> int:
+def round_to_integer(
+    value: decimal.Decimal, rounding: str = decimal.ROUND_HALF_UP
+) -> int:
     """
     Round a decimal to the nearest integer, using a given rounding method.
 
@@ -96,7 +100,10 @@ def round_to_integer(value: decimal.Decimal, rounding: str = decimal.ROUND_HALF_
 
 
 def clip_to_range(
-    val: types.Comparable[T], *, minval: types.Comparable[T], maxval: types.Comparable[T]
+    val: types.Comparable[T],
+    *,
+    minval: types.Comparable[T],
+    maxval: types.Comparable[T],
 ) -> types.Comparable[T]:
     """
     Clip the value to the min and max values given.
diff --git a/xocto/pact_testing.py b/xocto/pact_testing.py
index 95f97a5..7c1e1c6 100644
--- a/xocto/pact_testing.py
+++ b/xocto/pact_testing.py
@@ -23,7 +23,9 @@ def post(
     headers = {"Content-Type": "application/json"}
     if token:
         headers["Authorization"] = token
-    response = requests.post(url, data=json.dumps(data), headers=headers, verify=False)
+    response = requests.post(
+        url, data=json.dumps(data), headers=headers, verify=False
+    )
     return response.json()
 
 
diff --git a/xocto/ranges.py b/xocto/ranges.py
index bbd606e..4d862b4 100644
--- a/xocto/ranges.py
+++ b/xocto/ranges.py
@@ -28,7 +28,9 @@ class RangeBoundaries(enum.Enum):
     INCLUSIVE_INCLUSIVE = "[]"
 
     @classmethod
-    def from_bounds(cls, left_exclusive: bool, right_exclusive: bool) -> "RangeBoundaries":
+    def from_bounds(
+        cls, left_exclusive: bool, right_exclusive: bool
+    ) -> "RangeBoundaries":
         """
         Convenience method to get the relevant boundary type by specifiying the
         exclusivity of each end.
@@ -247,10 +249,14 @@ def __lt__(self, other: "Range[T]") -> bool:
                 return False
             else:
                 # If one endpoint is None then that range is greater, otherwise compare them
-                return (other.end is None) or (self.end is not None and self.end < other.end)
+                return (other.end is None) or (
+                    self.end is not None and self.end < other.end
+                )
         else:
             # If one endpoint is None then that range is lesser, otherwise compare them
-            return (self.start is None) or (other.start is not None and self.start < other.start)
+            return (self.start is None) or (
+                other.start is not None and self.start < other.start
+            )
 
     def __contains__(self, item: T) -> bool:
         """
@@ -286,13 +292,15 @@ def is_disjoint(self, other: "Range[T]") -> bool:
         """
         if self.end is not None and other.start is not None:
             if not (
-                self._is_inside_right_bound(other.start) and other._is_inside_left_bound(self.end)
+                self._is_inside_right_bound(other.start)
+                and other._is_inside_left_bound(self.end)
             ):
                 return True
 
         if self.start is not None and other.end is not None:
             if not (
-                self._is_inside_left_bound(other.end) and other._is_inside_right_bound(self.start)
+                self._is_inside_left_bound(other.end)
+                and other._is_inside_right_bound(self.start)
             ):
                 return True
 
@@ -320,7 +328,9 @@ def intersection(self, other: "Range[T]") -> Optional["Range[T]"]:
             end = range_r.end
             right_exclusive = range_r._is_right_exclusive
 
-        boundaries = RangeBoundaries.from_bounds(range_r._is_left_exclusive, right_exclusive)
+        boundaries = RangeBoundaries.from_bounds(
+            range_r._is_left_exclusive, right_exclusive
+        )
         return Range(range_r.start, end, boundaries=boundaries)
 
 
@@ -349,11 +359,15 @@ def union(self, other: "Range[T]") -> Optional["Range[T]"]:
             end = range_r.end
             right_exclusive = range_r._is_right_exclusive
 
-        boundaries = RangeBoundaries.from_bounds(range_l._is_left_exclusive, right_exclusive)
+        boundaries = RangeBoundaries.from_bounds(
+            range_l._is_left_exclusive, right_exclusive
+        )
 
         return Range(range_l.start, end, boundaries=boundaries)
 
-    def difference(self, other: "Range[T]") -> Optional[Union["Range[T]", "RangeSet[T]"]]:
+    def difference(
+        self, other: "Range[T]"
+    ) -> Optional[Union["Range[T]", "RangeSet[T]"]]:
         """
         Return a range or rangeset consisting of the bits of this range that do not
         intersect the other range (or None if this range is covered by the other range).
@@ -385,7 +399,9 @@ def difference(self, other: "Range[T]") -> Optional[Union["Range[T]", "RangeSet[
             boundaries = RangeBoundaries.from_bounds(
                 other._is_right_inclusive, self._is_right_exclusive
             )
-            upper_part: Optional["Range[T]"] = Range(other.end, self.end, boundaries=boundaries)
+            upper_part: Optional["Range[T]"] = Range(
+                other.end, self.end, boundaries=boundaries
+            )
         else:
             upper_part = None
 
@@ -668,7 +684,6 @@ def complement(self) -> RangeSet[T]:
             )
 
         for preceeding_range, current_range in zip(self._ranges[:-1], self._ranges[1:]):
-
             complement.append(
                 Range(
                     preceeding_range.end,
@@ -749,7 +764,9 @@ def __init__(self, start: datetime.datetime, end: datetime.datetime):
         """
         super().__init__(start, end, boundaries=RangeBoundaries.INCLUSIVE_EXCLUSIVE)
 
-    def intersection(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRange"]:
+    def intersection(
+        self, other: Range[datetime.datetime]
+    ) -> Optional["FiniteDatetimeRange"]:
         """
         Intersections with finite ranges will always be finite.
         """
@@ -760,7 +777,9 @@ def intersection(self, other: Range[datetime.datetime]) -> Optional["FiniteDatet
             assert base_intersection.boundaries == RangeBoundaries.INCLUSIVE_EXCLUSIVE
             return FiniteDatetimeRange(base_intersection.start, base_intersection.end)
 
-    def __and__(self, other: Range[datetime.datetime]) -> Optional["FiniteDatetimeRange"]:
+    def __and__(
+        self, other: Range[datetime.datetime]
+    ) -> Optional["FiniteDatetimeRange"]:
         return self.intersection(other)
 
     @property
diff --git a/xocto/settlement_periods.py b/xocto/settlement_periods.py
index 553f67a..5233d28 100644
--- a/xocto/settlement_periods.py
+++ b/xocto/settlement_periods.py
@@ -125,7 +125,9 @@ def convert_local_to_sp_and_date(
     # Date of the settlement period in the time zone
     delivery_date = _get_delivery_date(half_hourly_time, timezone_str, is_wholesale)
     # First settlement period in the time zone
-    first_delivery_time = _get_first_delivery_time(delivery_date, timezone_str, is_wholesale)
+    first_delivery_time = _get_first_delivery_time(
+        delivery_date, timezone_str, is_wholesale
+    )
     # Fetch settlement period
     delta = half_hourly_time - first_delivery_time
     settlement_period = ((int(delta.total_seconds()) // 60) + 30) // 30
diff --git a/xocto/storage/files.py b/xocto/storage/files.py
index 0ea3035..e9eec59 100644
--- a/xocto/storage/files.py
+++ b/xocto/storage/files.py
@@ -71,7 +71,9 @@ def convert_xlsx_to_csv(
     workbook = openpyxl.load_workbook(xlsx_filepath, data_only=True, read_only=True)
     sheet = workbook.active
 
-    csv_file, wr = _get_csv_file_and_writer(csv_filepath, encoding, errors, quoting, delimiter)
+    csv_file, wr = _get_csv_file_and_writer(
+        csv_filepath, encoding, errors, quoting, delimiter
+    )
     for row in sheet.rows:
         wr.writerow([cell.value for cell in row])
 
@@ -107,7 +109,9 @@ def convert_xls_to_csv(
     workbook = xlrd.open_workbook(xls_filepath)
     sheet = workbook.sheet_by_index(0)
 
-    csv_file, wr = _get_csv_file_and_writer(csv_filepath, encoding, errors, quoting, delimiter)
+    csv_file, wr = _get_csv_file_and_writer(
+        csv_filepath, encoding, errors, quoting, delimiter
+    )
     for rownum in range(sheet.nrows):
         row = sheet.row(rownum)
         values = []
@@ -138,7 +142,9 @@ def _get_csv_file_and_writer(
         delimiter = ","
 
     if csv_filepath:
-        csv_file: IO[str] = open(csv_filepath, mode="w+", encoding=encoding, errors=errors)
+        csv_file: IO[str] = open(
+            csv_filepath, mode="w+", encoding=encoding, errors=errors
+        )
     else:
         # `error' argument added in 3.8
         csv_file = tempfile.NamedTemporaryFile(mode="w+", encoding=encoding)
diff --git a/xocto/storage/storage.py b/xocto/storage/storage.py
index cfeedd0..b7a5f41 100644
--- a/xocto/storage/storage.py
+++ b/xocto/storage/storage.py
@@ -220,7 +220,11 @@ def store_versioned_file(
 
     @abc.abstractmethod
     def store_filepath(
-        self, namespace: str, filepath: str, overwrite: bool = False, dest_filepath: str = ""
+        self,
+        namespace: str,
+        filepath: str,
+        overwrite: bool = False,
+        dest_filepath: str = "",
     ) -> tuple[str, str]:
         raise NotImplementedError()
 
@@ -286,7 +290,10 @@ def get_key_or_store_file(
             return (self.bucket_name, key_path), False
 
         self.store_file(
-            namespace=namespace, filename=filepath, contents=contents, content_type=content_type
+            namespace=namespace,
+            filename=filepath,
+            contents=contents,
+            content_type=content_type,
         )
         return (self.bucket_name, key_path), True
 
@@ -299,7 +306,9 @@ def fetch_file(self, key_path: str, version_id: str | None = None) -> StreamingB
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> bytes:
+    def fetch_file_contents(
+        self, key_path: str, version_id: str | None = None
+    ) -> bytes:
         raise NotImplementedError()
 
     def fetch_text_file(
@@ -324,7 +333,9 @@ def fetch_url(
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def generate_presigned_post(self, *, key_path: str, expires_in: int = 60) -> PreSignedPost:
+    def generate_presigned_post(
+        self, *, key_path: str, expires_in: int = 60
+    ) -> PreSignedPost:
         raise NotImplementedError
 
     @abc.abstractmethod
@@ -472,7 +483,10 @@ def store_file(
         boto_client = self._get_boto_client()
 
         boto_client.upload_fileobj(
-            Fileobj=file_obj, Bucket=self.bucket_name, Key=key_path, ExtraArgs=extra_args
+            Fileobj=file_obj,
+            Bucket=self.bucket_name,
+            Key=key_path,
+            ExtraArgs=extra_args,
         )
 
         return self.bucket_name, key_path
@@ -496,14 +510,21 @@ def store_versioned_file(
         boto_client = self._get_boto_client()
 
         boto_response = boto_client.put_object(
-            Body=file_obj, Bucket=self.bucket_name, Key=key_path, **extra_args  # type: ignore[arg-type]
+            Body=file_obj,
+            Bucket=self.bucket_name,
+            Key=key_path,
+            **extra_args,  # type: ignore[arg-type]
         )
         version_id = boto_response["VersionId"]
 
         return self.bucket_name, key_path, version_id
 
     def store_filepath(
-        self, namespace: str, filepath: str, overwrite: bool = False, dest_filepath: str = ""
+        self,
+        namespace: str,
+        filepath: str,
+        overwrite: bool = False,
+        dest_filepath: str = "",
     ) -> tuple[str, str]:
         """
         Store a file in S3 given its local filepath.
@@ -533,13 +554,18 @@ def store_filepath(
         boto_client = self._get_boto_client()
 
         boto_client.upload_file(
-            Filename=filepath, Bucket=self.bucket_name, Key=key_path, ExtraArgs=extra_args
+            Filename=filepath,
+            Bucket=self.bucket_name,
+            Key=key_path,
+            ExtraArgs=extra_args,
         )
 
         return self.bucket_name, key_path
 
     def get_key(self, key_path: str, version_id: str | None = None) -> S3Object:
-        return S3Object(bucket_name=self.bucket_name, key=key_path, version_id=version_id)
+        return S3Object(
+            bucket_name=self.bucket_name, key=key_path, version_id=version_id
+        )
 
     def get_file_type(self, key_path: str) -> str:
         return self._get_boto_object_for_key(key=key_path).content_type
@@ -548,15 +574,19 @@ def fetch_file(self, key_path: str, version_id: str | None = None) -> StreamingB
         boto_object = self._get_boto_object_for_key(key=key_path, version_id=version_id)
         return boto_object.get()["Body"]
 
-    def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> bytes:
+    def fetch_file_contents(
+        self, key_path: str, version_id: str | None = None
+    ) -> bytes:
         return self.fetch_file(key_path, version_id).read()
 
     def fetch_file_contents_using_s3_select(
         self,
         key_path: str,
         raw_sql: str,
-        input_serializer: s3_select.CSVInputSerializer | s3_select.ParquetInputSerializer,
-        output_serializer: s3_select.CSVOutputSerializer | s3_select.JSONOutputSerializer,
+        input_serializer: s3_select.CSVInputSerializer
+        | s3_select.ParquetInputSerializer,
+        output_serializer: s3_select.CSVOutputSerializer
+        | s3_select.JSONOutputSerializer,
         compression_type: s3_select.CompressionType | None = None,
         scan_range: s3_select.ScanRange | None = None,
         chunk_size: int | None = None,
@@ -580,7 +610,9 @@ def fetch_file_contents_using_s3_select(
             )
         elif isinstance(input_serializer, s3_select.ParquetInputSerializer):
             if scan_range is not None:
-                raise ValueError("The scan_range parameter is not supported for parquet files")
+                raise ValueError(
+                    "The scan_range parameter is not supported for parquet files"
+                )
             serialization = s3_select.get_serializers_for_parquet_file(
                 output_serializer=output_serializer
             )
@@ -636,7 +668,9 @@ def fetch_url(
             "get_object", Params=params, ExpiresIn=expires_in
         )
 
-    def generate_presigned_post(self, *, key_path: str, expires_in: int = 60) -> PreSignedPost:
+    def generate_presigned_post(
+        self, *, key_path: str, expires_in: int = 60
+    ) -> PreSignedPost:
         boto_client = self._get_boto_client()
         presigned_post = boto_client.generate_presigned_post(
             Bucket=self.bucket_name, Key=key_path, ExpiresIn=expires_in
@@ -771,7 +805,9 @@ def _get_policy(self) -> str | None:
 
     def _get_boto_client(self) -> S3Client:
         return boto3.client(
-            "s3", region_name=settings.AWS_REGION, endpoint_url=settings.AWS_S3_ENDPOINT_URL
+            "s3",
+            region_name=settings.AWS_REGION,
+            endpoint_url=settings.AWS_S3_ENDPOINT_URL,
         )
 
     def _get_boto_bucket(self) -> service_resource.Bucket:
@@ -811,7 +847,9 @@ def _get_boto_object_for_key(
         # It'd cause the S3SubdirectoryFileStore to add the subdir to the path, which is
         # not safe since the key may have come from somewhere that already includes this.
         return self._get_boto_object(
-            s3_object=S3Object(bucket_name=self.bucket_name, key=key, version_id=version_id)
+            s3_object=S3Object(
+                bucket_name=self.bucket_name, key=key, version_id=version_id
+            )
         )
 
     def _select_object_content(
@@ -820,7 +858,6 @@ def _select_object_content(
         boto_client: S3Client,
         select_object_content_parameters: dict[str, Any],
     ) -> Iterator[str]:
-
         # Error codes reference: https://docs.aws.amazon.com/AmazonS3/latest/API/ErrorResponses.html#SelectObjectContentErrorCodeList
         invalid_response_statuses = [400, 401, 403, 500]
 
@@ -833,7 +870,9 @@ def _select_object_content(
                 error.response.get("Error", {}).get("HTTPStatusCode", None)
                 in invalid_response_statuses
             ):
-                raise S3SelectUnexpectedResponse("Received invalid response from S3 Select")
+                raise S3SelectUnexpectedResponse(
+                    "Received invalid response from S3 Select"
+                )
             raise
 
         if response["ResponseMetadata"]["HTTPStatusCode"] in invalid_response_statuses:
@@ -1082,7 +1121,11 @@ def _build_download_filepath(self, key_path: str) -> str:
         return os.path.join(self.storage_root, key_path)
 
     def store_filepath(
-        self, namespace: str, filepath: str, overwrite: bool = False, dest_filepath: str = ""
+        self,
+        namespace: str,
+        filepath: str,
+        overwrite: bool = False,
+        dest_filepath: str = "",
     ) -> tuple[str, str]:
         if not dest_filepath:
             dest_filepath = os.path.basename(filepath)
@@ -1098,7 +1141,9 @@ def store_filepath(
         return self.bucket_name, store_filepath
 
     def get_key(self, key_path: str, version_id: str | None = None) -> S3Object:
-        return S3Object(bucket_name=self.bucket_name, key=key_path, version_id=version_id)
+        return S3Object(
+            bucket_name=self.bucket_name, key=key_path, version_id=version_id
+        )
 
     def get_file_type(self, key_path: str) -> str:
         mime = magic.Magic(mime=True)
@@ -1122,17 +1167,23 @@ def fetch_file(self, key_path: str, version_id: str | None = None) -> StreamingB
             raise KeyDoesNotExist(f"Key {key_path} was not found at {file_path}")
         with open(file_path, "rb") as f:
             raw_stream = io.BytesIO(f.read())
-            return StreamingBody(raw_stream=raw_stream, content_length=files.size(raw_stream))
+            return StreamingBody(
+                raw_stream=raw_stream, content_length=files.size(raw_stream)
+            )
 
-    def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> bytes:
+    def fetch_file_contents(
+        self, key_path: str, version_id: str | None = None
+    ) -> bytes:
         return self.fetch_file(key_path, version_id).read()
 
     def fetch_file_contents_using_s3_select(
         self,
         key_path: str,
         raw_sql: str,
-        input_serializer: s3_select.CSVInputSerializer | s3_select.ParquetInputSerializer,
-        output_serializer: s3_select.CSVOutputSerializer | s3_select.JSONOutputSerializer,
+        input_serializer: s3_select.CSVInputSerializer
+        | s3_select.ParquetInputSerializer,
+        output_serializer: s3_select.CSVOutputSerializer
+        | s3_select.JSONOutputSerializer,
         compression_type: s3_select.CompressionType | None = None,
         scan_range: s3_select.ScanRange | None = None,
         chunk_size: int | None = None,
@@ -1173,7 +1224,9 @@ def fetch_file_contents_using_s3_select(
                 raise NotImplementedError(
                     "Only newline ('\n') is supported as the record delimiter for JSON output in localdev"
                 )
-            result = filtered_df.to_json(orient="records", lines=True, date_format="iso")
+            result = filtered_df.to_json(
+                orient="records", lines=True, date_format="iso"
+            )
         elif isinstance(output_serializer, s3_select.CSVOutputSerializer):
             result = self.output_csv_with_serializer(
                 df=filtered_df,
@@ -1187,7 +1240,6 @@ def query_dataframe_with_sql(
         raw_sql: str,
         df: pd.DataFrame,
     ) -> pd.DataFrame:
-
         # s3 select requires the from clause to use the identifier "s3object"
         # it is case insensitive however so people's queries may use different cases
         S3_OBJECT_QUERY_IDENTIFIER = "s3object"
@@ -1209,7 +1261,6 @@ def read_csv_with_serializer(
         csv_input_serializer: s3_select.CSVInputSerializer,
         compression_type: s3_select.CompressionType | None = None,
     ) -> pd.DataFrame:
-
         input_serializer = csv_input_serializer.to_dict()
 
         field_delimiter = input_serializer.get("FieldDelimiter", ",")
@@ -1221,7 +1272,6 @@ def read_csv_with_serializer(
         header: int | None | str
 
         if "FileHeaderInfo" in input_serializer.keys():
-
             if input_serializer["FileHeaderInfo"] == "NONE":
                 header = None
             elif input_serializer["FileHeaderInfo"] == "IGNORE":
@@ -1262,7 +1312,6 @@ def output_csv_with_serializer(
         df: pd.DataFrame,
         output_serializer: s3_select.CSVOutputSerializer,
     ) -> str:
-
         output_serializer_dict = output_serializer.to_dict()
 
         field_delimiter = output_serializer_dict.get("FieldDelimiter", ",")
@@ -1286,7 +1335,9 @@ def output_csv_with_serializer(
             "quoting": quoting,
         }
 
-        kwargs = {key: value for key, value in default_kwargs.items() if value is not None}
+        kwargs = {
+            key: value for key, value in default_kwargs.items() if value is not None
+        }
 
         result = df.to_csv(index=False, **kwargs)
 
@@ -1319,7 +1370,9 @@ def fetch_url(
 
         return f"{settings.MEDIA_URL}{self.bucket_name}/{url_path}"
 
-    def generate_presigned_post(self, *, key_path: str, expires_in: int = 60) -> PreSignedPost:
+    def generate_presigned_post(
+        self, *, key_path: str, expires_in: int = 60
+    ) -> PreSignedPost:
         return PreSignedPost(
             # Resolves to a localdev/storage url
             url=reverse("fake-presigned-post-upload"),
@@ -1380,11 +1433,15 @@ def get_last_modified(self, key_path: str) -> datetime.datetime:
         return datetime.datetime.fromtimestamp(file_stats.st_mtime)
 
     def copy(self, *, s3_object: S3Object, destination: str) -> S3Object:
-        shutil.copyfile(src=self._filepath("", s3_object.key), dst=self._filepath("", destination))
+        shutil.copyfile(
+            src=self._filepath("", s3_object.key), dst=self._filepath("", destination)
+        )
         return S3Object(bucket_name=self.bucket_name, key=destination)
 
     def rename(self, *, s3_object: S3Object, destination: str) -> S3Object:
-        os.rename(src=self._filepath("", s3_object.key), dst=self._filepath("", destination))
+        os.rename(
+            src=self._filepath("", s3_object.key), dst=self._filepath("", destination)
+        )
         return S3Object(bucket_name=self.bucket_name, key=destination)
 
     def delete(self, *, s3_object: S3Object) -> None:
@@ -1427,9 +1484,13 @@ class LocalEmailStore(LocalFileStore):
     ]
 
     def __init__(self, bucket_name: str = "", *args: Any, **kwargs: Any) -> None:
-        super().__init__(bucket_name=bucket_name, storage_root=settings.EMAIL_STORAGE_ROOT)
+        super().__init__(
+            bucket_name=bucket_name, storage_root=settings.EMAIL_STORAGE_ROOT
+        )
 
-    def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> bytes:
+    def fetch_file_contents(
+        self, key_path: str, version_id: str | None = None
+    ) -> bytes:
         # Randomly select one of the fixture files
         key_path = random.choice(self.email_keys)
         return super().fetch_file_contents(key_path, version_id)
@@ -1513,24 +1574,39 @@ def store_versioned_file(
         content_type: str = "",
     ) -> tuple[str, str, str]:
         version = str(uuid.uuid4())
-        self.versioned_buffers[self.bucket_name][key_path][version] = _to_bytes(contents=contents)
+        self.versioned_buffers[self.bucket_name][key_path][version] = _to_bytes(
+            contents=contents
+        )
         self.buffers[self.bucket_name][key_path] = _to_bytes(contents=contents)
         return self.bucket_name, key_path, version
 
     def store_filepath(
-        self, namespace: str, filepath: str, overwrite: bool = False, dest_filepath: str = ""
+        self,
+        namespace: str,
+        filepath: str,
+        overwrite: bool = False,
+        dest_filepath: str = "",
     ) -> tuple[str, str]:
         with open(filepath, "rb") as f:
             if not dest_filepath:
                 dest_filepath = os.path.basename(filepath)
-            return self.store_file(namespace, dest_filepath, f.read(), overwrite=overwrite)
+            return self.store_file(
+                namespace, dest_filepath, f.read(), overwrite=overwrite
+            )
 
-    def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> bytes:
+    def fetch_file_contents(
+        self, key_path: str, version_id: str | None = None
+    ) -> bytes:
         if version_id:
             versioned_bucket = self.versioned_buffers[self.bucket_name]
-            if key_path not in versioned_bucket or version_id not in versioned_bucket[key_path]:
+            if (
+                key_path not in versioned_bucket
+                or version_id not in versioned_bucket[key_path]
+            ):
                 raise KeyDoesNotExist(
-                    "Key with path %s and version %s was not found" % key_path % version_id
+                    "Key with path %s and version %s was not found"
+                    # Interpolate as a tuple: chaining "%" here would raise TypeError
+                    % (key_path, version_id)
                 )
             return versioned_bucket[key_path][version_id]
         else:
@@ -1540,7 +1616,9 @@ def fetch_file_contents(self, key_path: str, version_id: str | None = None) -> b
             return bucket[key_path]
 
     def get_key(self, key_path: str, version_id: str | None = None) -> S3Object:
-        return S3Object(bucket_name=self.bucket_name, key=key_path, version_id=version_id)
+        return S3Object(
+            bucket_name=self.bucket_name, key=key_path, version_id=version_id
+        )
 
     def get_file_type(self, key_path: str) -> str:
         mime = magic.Magic(mime=True)
@@ -1548,7 +1626,9 @@ def get_file_type(self, key_path: str) -> str:
 
     def fetch_file(self, key_path: str, version_id: str | None = None) -> StreamingBody:
         raw_stream = io.BytesIO(self.fetch_file_contents(key_path, version_id))
-        return StreamingBody(raw_stream=raw_stream, content_length=files.size(raw_stream))
+        return StreamingBody(
+            raw_stream=raw_stream, content_length=files.size(raw_stream)
+        )
 
     def fetch_url(
         self,
@@ -1607,7 +1687,9 @@ def clear(self) -> None:
         for bucket in self.buffers.values():
             bucket.clear()
 
-    def generate_presigned_post(self, *, key_path: str, expires_in: int = 60) -> PreSignedPost:
+    def generate_presigned_post(
+        self, *, key_path: str, expires_in: int = 60
+    ) -> PreSignedPost:
         return PreSignedPost(
             # Resolves to a localdev/storage url
             url=reverse("fake-presigned-post-upload"),
@@ -1631,7 +1713,9 @@ def get_last_modified(self, key_path: str) -> datetime.datetime:
 
 
 def store(
-    bucket_name: str, use_date_in_key_path: bool = True, set_acl_bucket_owner: bool = False
+    bucket_name: str,
+    use_date_in_key_path: bool = True,
+    set_acl_bucket_owner: bool = False,
 ) -> BaseS3FileStore:
     """
     Return the appropriate storage instance for a given bucket.
@@ -1664,7 +1748,9 @@ def user_documents(use_date_in_key_path: bool = True) -> BaseS3FileStore:
     """
     Return the user documents store.
""" - return store(settings.S3_USER_DOCUMENTS_BUCKET, use_date_in_key_path=use_date_in_key_path) + return store( + settings.S3_USER_DOCUMENTS_BUCKET, use_date_in_key_path=use_date_in_key_path + ) def archive(use_date_in_key_path: bool = True) -> BaseS3FileStore: @@ -1705,13 +1791,16 @@ def outbound_flow_store() -> BaseS3FileStore: storage_class = import_string(settings.STORAGE_BACKEND) if storage_class == S3FileStore: return storage_class( - bucket_name=settings.INTEGRATION_FLOW_S3_OUTBOUND_BUCKET, use_date_in_key_path=False + bucket_name=settings.INTEGRATION_FLOW_S3_OUTBOUND_BUCKET, + use_date_in_key_path=False, ) else: return storage_class(bucket_name=settings.INTEGRATION_FLOW_S3_OUTBOUND_BUCKET) -def from_uri(uri: str) -> FileSystemFileStore | S3SubdirectoryFileStore | MemoryFileStore: +def from_uri( + uri: str +) -> FileSystemFileStore | S3SubdirectoryFileStore | MemoryFileStore: """ :raises ValueError: if the URI does not contain a scheme for a supported storage system. """ diff --git a/xocto/types.py b/xocto/types.py index 2291e34..78ef8d1 100644 --- a/xocto/types.py +++ b/xocto/types.py @@ -22,7 +22,9 @@ # Helpers for declaring django relations on classes # These are one-item Unions so that mypy knows they are type aliases and not strings ForeignKey = Union["models.ForeignKey[Union[Model, Combinable], Model]"] -OptionalForeignKey = Union["models.ForeignKey[Union[Model, Combinable, None], Union[Model, None]]"] +OptionalForeignKey = Union[ + "models.ForeignKey[Union[Model, Combinable, None], Union[Model, None]]" +] OneToOneField = Union["models.OneToOneField[Union[Model, Combinable], Model]"] OptionalOneToOneField = Union[ diff --git a/xocto/urls.py b/xocto/urls.py index 5f1d3e9..e6c5c40 100644 --- a/xocto/urls.py +++ b/xocto/urls.py @@ -23,7 +23,9 @@ def pop_url_query_param(url: str, key: str) -> tuple[str, str | None]: ValueError: ... """ parsed_url = parse.urlparse(url) - query_dict = parse.parse_qs(parsed_url.query, keep_blank_values=True, errors="strict") + query_dict = parse.parse_qs( + parsed_url.query, keep_blank_values=True, errors="strict" + ) query_value_list = query_dict.pop(key, (None,)) if len(query_value_list) != 1: raise ValueError(f"Cannot pop multi-valued query param: Key {key!r} in {url!r}") @@ -107,11 +109,15 @@ def parse_file_destination_from_url(url: str) -> tuple[str, str, str]: else: parsed_url = parse.urlparse(url) destination_path = parsed_url.path - upload_path = os.path.abspath(os.path.join(destination_path, url_relative_upload_path)) + upload_path = os.path.abspath( + os.path.join(destination_path, url_relative_upload_path) + ) common_path = os.path.commonpath((destination_path, upload_path)) new_url = parse.urlunparse(parsed_url._replace(path=common_path)) new_url = _fix_url_scheme(old_url=url, new_url=new_url) - upload_path = os.path.abspath(os.path.join(destination_path, url_relative_upload_path)) + upload_path = os.path.abspath( + os.path.join(destination_path, url_relative_upload_path) + ) rel_destination_path = os.path.relpath(destination_path, common_path) rel_upload_path = os.path.relpath(upload_path, common_path) return new_url, rel_destination_path, rel_upload_path