From c84c2921e66fc5f78918f1b8bae82df4c41aa553 Mon Sep 17 00:00:00 2001 From: Alex Lowe Date: Tue, 10 Dec 2024 20:12:28 -0500 Subject: [PATCH] feat: add publisher gateway client This adds a very basic publisher gateway client that meets the needs of charmcraft's new 'create-track' command. The client cannot login yet and still depends on the old requests-based clients for many features. --- craft_store/__init__.py | 5 +- craft_store/candidauth.py | 72 ++++++++++ craft_store/errors.py | 24 +++- craft_store/publishergateway/__init__.py | 35 +++++ craft_store/publishergateway/_publishergw.py | 106 +++++++++++++++ craft_store/publishergateway/_request.py | 41 ++++++ craft_store/publishergateway/_response.py | 71 ++++++++++ pyproject.toml | 4 +- tests/integration/conftest.py | 27 +++- .../integration/publishergateway/__init__.py | 0 .../integration/publishergateway/test_read.py | 34 +++++ .../publishergateway/test_write.py | 95 +++++++++++++ tests/unit/conftest.py | 7 + tests/unit/test_candid_auth.py | 45 ++++++ tests/unit/test_publishergateway.py | 128 ++++++++++++++++++ 15 files changed, 685 insertions(+), 9 deletions(-) create mode 100644 craft_store/candidauth.py create mode 100644 craft_store/publishergateway/__init__.py create mode 100644 craft_store/publishergateway/_publishergw.py create mode 100644 craft_store/publishergateway/_request.py create mode 100644 craft_store/publishergateway/_response.py create mode 100644 tests/integration/publishergateway/__init__.py create mode 100644 tests/integration/publishergateway/test_read.py create mode 100644 tests/integration/publishergateway/test_write.py create mode 100644 tests/unit/test_candid_auth.py create mode 100644 tests/unit/test_publishergateway.py diff --git a/craft_store/__init__.py b/craft_store/__init__.py index cf411cd4..7b777a53 100644 --- a/craft_store/__init__.py +++ b/craft_store/__init__.py @@ -20,9 +20,10 @@ __version__ = "3.0.0" -from . import creds, endpoints, errors, models +from . import creds, endpoints, errors, models, publishergateway from .auth import Auth from .base_client import BaseClient +from .candidauth import CandidAuth from .developer_token_auth import DeveloperTokenAuth from .http_client import HTTPClient from .store_client import StoreClient @@ -33,8 +34,10 @@ "endpoints", "errors", "models", + "publishergateway", "Auth", "BaseClient", + "CandidAuth", "HTTPClient", "StoreClient", "UbuntuOneStoreClient", diff --git a/craft_store/candidauth.py b/craft_store/candidauth.py new file mode 100644 index 00000000..76d5c3e6 --- /dev/null +++ b/craft_store/candidauth.py @@ -0,0 +1,72 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +"""Client for making requests towards publisher gateway.""" + +from collections.abc import Generator +from logging import getLogger +from typing import Literal + +import httpx + +from craft_store import auth, creds, errors + +logger = getLogger(__name__) + + +class CandidAuth(httpx.Auth): + """Request authentication using developer token.""" + + def __init__( + self, + *, + auth: auth.Auth, + auth_type: Literal["bearer", "macaroon"] = "bearer", + ) -> None: + self._auth = auth + self._auth_type = auth_type + self._token: str | None = None + + def auth_flow( + self, + request: httpx.Request, + ) -> Generator[httpx.Request, httpx.Response, None]: + """Update request to include Authorization header.""" + if self._token is None: + logger.debug("Getting candid macaroon from keyring") + self._token = self.get_token_from_keyring() + + self._update_headers(request) + yield request + + def get_token_from_keyring(self) -> str: + """Get token stored in the credentials storage.""" + logger.debug("Getting candid from credential storage") + return creds.unmarshal_candid_credentials(self._auth.get_credentials()) + + def _update_headers(self, request: httpx.Request) -> None: + """Add token to the request.""" + logger.debug("Adding ephemeral token to request headers") + if self._token is None: + raise errors.DeveloperTokenUnavailableError( + message="Candid token is not available" + ) + request.headers["Authorization"] = self._format_auth_header() + + def _format_auth_header(self) -> str: + if self._auth_type == "bearer": + return f"Bearer {self._token}" + return f"Macaroon {self._token}" diff --git a/craft_store/errors.py b/craft_store/errors.py index af763a34..71218100 100644 --- a/craft_store/errors.py +++ b/craft_store/errors.py @@ -15,11 +15,13 @@ # along with this program. If not, see . """Craft Store errors.""" +from __future__ import annotations import contextlib import logging from typing import Any +import httpx import requests import urllib3 import urllib3.exceptions @@ -31,9 +33,19 @@ class CraftStoreError(Exception): """Base class error for craft-store.""" - def __init__(self, message: str, resolution: str | None = None) -> None: + def __init__( + self, + message: str, + details: str | None = None, + resolution: str | None = None, + store_errors: StoreErrorList | None = None, + ) -> None: super().__init__(message) + if not details: + details = str(store_errors) + self.details = details self.resolution = resolution + self.store_errors = store_errors class NetworkError(CraftStoreError): @@ -75,7 +87,7 @@ def __repr__(self) -> str: if code: code_list.append(code) - return "" + return f"" def __contains__(self, error_code: str) -> bool: return any(error.get("code") == error_code for error in self._error_list) @@ -111,7 +123,7 @@ def _get_raw_error_list(self) -> list[dict[str, str]]: return error_list - def __init__(self, response: requests.Response) -> None: + def __init__(self, response: requests.Response | httpx.Response) -> None: self.response = response try: @@ -126,9 +138,13 @@ def __init__(self, response: requests.Response) -> None: with contextlib.suppress(KeyError): message = "Store operation failed:\n" + str(self.error_list) if message is None: + if isinstance(response, httpx.Response): + reason = response.reason_phrase + else: + reason = response.reason message = ( "Issue encountered while processing your request: " - f"[{response.status_code}] {response.reason}." + f"[{response.status_code}] {reason}." ) super().__init__(message) diff --git a/craft_store/publishergateway/__init__.py b/craft_store/publishergateway/__init__.py new file mode 100644 index 00000000..f628f8a2 --- /dev/null +++ b/craft_store/publishergateway/__init__.py @@ -0,0 +1,35 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Package containing the Publisher Gateway client and relevant metadata.""" + +from ._request import ( + CreateTrackRequest, +) + +from ._response import ( + PackageMetadata, + PublisherMetadata, + TrackMetadata, +) +from ._publishergw import PublisherGateway + +__all__ = [ + "CreateTrackRequest", + "PackageMetadata", + "PublisherMetadata", + "TrackMetadata", + "PublisherGateway", +] diff --git a/craft_store/publishergateway/_publishergw.py b/craft_store/publishergateway/_publishergw.py new file mode 100644 index 00000000..a2d7bea3 --- /dev/null +++ b/craft_store/publishergateway/_publishergw.py @@ -0,0 +1,106 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Client for the publisher gateway.""" + +from typing import cast +import httpx + +from craft_store import auth, candidauth, errors + +from . import _request, _response + + +class PublisherGateway: + """Client for the publisher gateway. + + This class is a client wrapper for the Canonical Publisher Gateway. + The latest version of the server API can be seen at: https://api.charmhub.io/docs/ + + Each instance is only valid for one particular namespace. + """ + + def __init__(self, base_url: str, namespace: str, auth: auth.Auth) -> None: + self._namespace = namespace + self._client = httpx.Client( + base_url=base_url, + auth=candidauth.CandidAuth(auth=auth, auth_type="macaroon"), + ) + + @staticmethod + def _check_error(response: httpx.Response) -> None: + if response.is_success: + return + try: + error_response = response.json() + except Exception as exc: + raise errors.CraftStoreError( + f"Invalid response from server ({response.status_code})", + details=response.text, + ) from exc + error_list = error_response.get("error-list", []) + if response.status_code >= 500: + brief = f"Store had an error ({response.status_code})" + else: + brief = f"Error {response.status_code} returned from store" + if len(error_list) == 1: + brief = f"{brief}: {error_list[0].get('message')}" + else: + brief = f"{brief}. See log for details" + raise errors.CraftStoreError( + brief, store_errors=errors.StoreErrorList(error_list) + ) + + def get_package_metadata(self, name: str) -> _response.PackageMetadata: + """Get general metadata for a package. + + :param name: The name of the package to query. + :returns: A dictionary matching the result from the publisher gateway. + + API docs: https://api.charmhub.io/docs/default.html#package_metadata + """ + response = self._client.get( + url=f"/v1/{self._namespace}/{name}", + ) + self._check_error(response) + return cast(_response.PackageMetadata, response.json()["metadata"]) + + def create_tracks(self, name: str, *tracks: _request.CreateTrackRequest) -> int: + """Create one or more tracks in the store. + + :param name: The store name (i.e. the specific charm, snap or other package) + to which this track will be attached. + :param tracks: Each track is a dictionary mapping query values. + :returns: The number of tracks created by the store. + :raises: ValueError if a track name is invalid. + + API docs: https://api.charmhub.io/docs/default.html#create_tracks + """ + bad_track_names = { + track["name"] + for track in tracks + if not _request.TRACK_NAME_REGEX.match(track["name"]) + or len(track["name"]) > 28 + } + if bad_track_names: + bad_tracks = ", ".join(sorted(bad_track_names)) + raise ValueError(f"The following track names are invalid: {bad_tracks}") + + response = self._client.post( + f"/v1/{self._namespace}/{name}/tracks", json=tracks + ) + self._check_error(response) + + return int(response.json()["num-tracks-created"]) diff --git a/craft_store/publishergateway/_request.py b/craft_store/publishergateway/_request.py new file mode 100644 index 00000000..fd01bb56 --- /dev/null +++ b/craft_store/publishergateway/_request.py @@ -0,0 +1,41 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Request models for the publisher gateway.""" + +import re +from typing import Annotated, TypedDict + +import annotated_types +from typing_extensions import NotRequired + +TRACK_NAME_REGEX = re.compile(r"^[a-zA-Z0-9](?:[_.-]?[a-zA-Z0-9])*$") +"""A regular expression guarding track names. + +Retrieved from https://api.staging.charmhub.io/docs/default.html#create_tracks +""" + +CreateTrackRequest = TypedDict( + "CreateTrackRequest", + { + "name": Annotated[ + str, + annotated_types.Len(1, 28), + annotated_types.Predicate(lambda name: bool(TRACK_NAME_REGEX.match(name))), + ], + "version-pattern": NotRequired[str | None], + "automatic-phasing-percentage": NotRequired[str | None], + }, +) diff --git a/craft_store/publishergateway/_response.py b/craft_store/publishergateway/_response.py new file mode 100644 index 00000000..5f0c6e2d --- /dev/null +++ b/craft_store/publishergateway/_response.py @@ -0,0 +1,71 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Response models for the publisher gateway.""" + +from typing import TypedDict + +from typing_extensions import NotRequired + +PublisherMetadata = TypedDict( + "PublisherMetadata", + { + "display-name": str | None, + "email": NotRequired[str], + "id": str, + "username": str | None, + "validation": NotRequired[str], + }, +) + +TrackMetadata = TypedDict( + "TrackMetadata", + { + "name": str, + "version-pattern": str | None, + "automatic-phasing-percentage": float | None, + "created-at": str, + }, +) + + +PackageMetadata = TypedDict( + "PackageMetadata", + { + "authority": NotRequired[str | None], + "contact": NotRequired[str | None], + "default-track": NotRequired[str | None], + "description": NotRequired[str | None], + "id": str, + "links": NotRequired[list[str] | None], + "media": list[dict[str, str]], + "name": NotRequired[str | None], + "private": bool, + "publisher": PublisherMetadata, + "status": str, + "store": str, + "summary": NotRequired[str | None], + "title": NotRequired[str | None], + "track-guardrails": NotRequired[ + dict[ + str, + str, + ] + ], + "tracks": list[TrackMetadata] | None, + "type": str, + "website": NotRequired[str | None], + }, +) diff --git a/pyproject.toml b/pyproject.toml index 814ed389..dcea9ac6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ authors = [ {name = "Canonical Ltd.", email = "snapcraft@lists.snapcraft.io"}, ] dependencies = [ + "annotated-types>=0.6.0", "keyring>=23.0", "overrides>=7.0.0", "requests>=2.27.0", @@ -142,7 +143,8 @@ minversion = "7.0" testpaths = "tests" xfail_strict = true markers = [ - "disable_fake_keyring" + "disable_fake_keyring", + "slow: tests that take a long time", ] [tool.coverage.run] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4999c9ff..98c11184 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -21,15 +21,20 @@ import pytest import yaml -from craft_store import StoreClient, endpoints +from craft_store import StoreClient, auth, endpoints, publishergateway + + +@pytest.fixture(scope="session") +def charmhub_base_url() -> str: + return os.getenv("CRAFT_STORE_CHARMHUB", "https://api.staging.charmhub.io") @pytest.fixture -def charm_client(): +def charm_client(charmhub_base_url): """A common StoreClient for charms""" return StoreClient( application_name="integration-test", - base_url="https://api.staging.charmhub.io", + base_url=charmhub_base_url, storage_base_url="https://storage.staging.snapcraftcontent.com", endpoints=endpoints.CHARMHUB, user_agent="integration-tests", @@ -47,6 +52,22 @@ def charmhub_charm_name(): return os.getenv("CRAFT_STORE_TEST_CHARM", default="craft-store-test") +@pytest.fixture +def charmhub_auth(charmhub_base_url): + return auth.Auth( + application_name="craft-store-integration-tests", + host=charmhub_base_url, + environment_auth="CRAFT_STORE_CHARMCRAFT_CREDENTIALS", + ) + + +@pytest.fixture +def publisher_gateway(charmhub_base_url, charmhub_auth): + return publishergateway.PublisherGateway( + base_url=charmhub_base_url, namespace="charm", auth=charmhub_auth + ) + + @pytest.fixture def fake_charm_file(tmp_path, charmhub_charm_name): """Provide a fake charm to upload to charmhub.""" diff --git a/tests/integration/publishergateway/__init__.py b/tests/integration/publishergateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/publishergateway/test_read.py b/tests/integration/publishergateway/test_read.py new file mode 100644 index 00000000..1df570db --- /dev/null +++ b/tests/integration/publishergateway/test_read.py @@ -0,0 +1,34 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Tests that only involve reading from the store.""" + + +from craft_store import publishergateway + +from tests.integration.conftest import needs_charmhub_credentials + + +@needs_charmhub_credentials() +def test_get_package_metadata( + publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str +): + metadata = publisher_gateway.get_package_metadata(charmhub_charm_name) + assert metadata.get("name") == charmhub_charm_name + assert metadata.get("default-track") + assert len(metadata["id"]) == len("sCPqM62aJhbLUJmpPfFbsxbd2zpR6dcu") + assert metadata.get("default-track") in { + track["name"] for track in metadata.get("tracks") or [] + } diff --git a/tests/integration/publishergateway/test_write.py b/tests/integration/publishergateway/test_write.py new file mode 100644 index 00000000..ed3ae819 --- /dev/null +++ b/tests/integration/publishergateway/test_write.py @@ -0,0 +1,95 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Single-endpoint write tests (likely with a read query after).""" + + +import contextlib +import time + +import pytest +from craft_store import errors, publishergateway + +from tests.integration.conftest import needs_charmhub_credentials + + +@pytest.mark.slow +@needs_charmhub_credentials() +@pytest.mark.parametrize("version_pattern", [None, r"\d+"]) +@pytest.mark.parametrize("percentages", [None, 50]) +def test_create_tracks( + publisher_gateway: publishergateway.PublisherGateway, + charmhub_charm_name: str, + version_pattern, + percentages, +): + track_name = str(time.time_ns()) + + tracks_created = publisher_gateway.create_tracks( + charmhub_charm_name, + { + "name": track_name, + "version-pattern": version_pattern, + "automatic-phasing-percentage": percentages, + }, + ) + assert tracks_created == 1 + + metadata = publisher_gateway.get_package_metadata(charmhub_charm_name) + if "tracks" not in metadata or not metadata["tracks"]: + raise ValueError("No tracks returned from the store") + + for track in metadata["tracks"]: + if track["name"] != track_name: + continue + assert track["version-pattern"] == version_pattern + assert track["automatic-phasing-percentage"] == percentages + break + else: + raise ValueError(f"Track {track_name} created but not returned from the store.") + + +@pytest.mark.slow +@needs_charmhub_credentials() +def test_create_disallowed_track( + publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str +): + track_name = "disallowed" + + with pytest.raises(errors.CraftStoreError, match="Invalid track name") as exc_info: + publisher_gateway.create_tracks(charmhub_charm_name, {"name": track_name}) + + assert exc_info.value.store_errors is not None + assert "invalid-tracks" in exc_info.value.store_errors + + +@pytest.mark.slow +@needs_charmhub_credentials() +def test_create_existing_track( + publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str +): + track_name = "1" + + # Suppress the error because we don't care about the first time + with contextlib.suppress(errors.CraftStoreError): + publisher_gateway.create_tracks(charmhub_charm_name, {"name": track_name}) + + with pytest.raises( + errors.CraftStoreError, match="Conflicting track exists" + ) as exc_info: + publisher_gateway.create_tracks(charmhub_charm_name, {"name": track_name}) + + assert exc_info.value.store_errors is not None + assert "conflicting-tracks" in exc_info.value.store_errors diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 6f72900b..5cba9fb0 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -16,8 +16,10 @@ import datetime from typing import Any +from unittest import mock from unittest.mock import patch +import craft_store import pytest @@ -108,3 +110,8 @@ def new_auth(request) -> bool: :see: base_client.wrap_credentials() """ return request.param + + +@pytest.fixture +def mock_auth(): + return mock.Mock(spec=craft_store.Auth) diff --git a/tests/unit/test_candid_auth.py b/tests/unit/test_candid_auth.py new file mode 100644 index 00000000..e4994797 --- /dev/null +++ b/tests/unit/test_candid_auth.py @@ -0,0 +1,45 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -* +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +# +"""Tests for authorizing requests using CandidAuth.""" + + +import httpx +import pytest +from craft_store import CandidAuth + + +@pytest.fixture +def candid_auth(mock_auth): + return CandidAuth( + auth=mock_auth, + ) + + +def test_get_token_from_keyring(mock_auth, candid_auth): + mock_auth.get_credentials.return_value = "{}" + + assert candid_auth.get_token_from_keyring() == "{}" + + +def test_auth_flow(mock_auth, candid_auth): + mock_auth.get_credentials.return_value = "{}" + + request = httpx.Request("GET", "http://localhost") + + next(candid_auth.auth_flow(request)) + + assert request.headers["Authorization"] == "Bearer {}" diff --git a/tests/unit/test_publishergateway.py b/tests/unit/test_publishergateway.py new file mode 100644 index 00000000..261e8d5c --- /dev/null +++ b/tests/unit/test_publishergateway.py @@ -0,0 +1,128 @@ +# -*- Mode:Python; indent-tabs-mode:nil; tab-width:4 -*- +# +# Copyright 2024 Canonical Ltd. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License version 3 as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Unit tests for the publisher gateway.""" + +from unittest import mock + +import httpx +import pytest +from craft_store import errors, publishergateway + + +@pytest.fixture +def mock_httpx_client(): + return mock.Mock(spec=httpx.Client) + + +@pytest.fixture +def publisher_gateway(mock_httpx_client): + gw = publishergateway.PublisherGateway("http://localhost", "charm", mock.Mock()) + gw._client = mock_httpx_client + return gw + + +@pytest.mark.parametrize("response", [httpx.Response(status_code=204)]) +def test_check_error_on_success(response: httpx.Response): + assert publishergateway.PublisherGateway._check_error(response) is None + + +@pytest.mark.parametrize( + ("response", "match"), + [ + pytest.param( + httpx.Response(503, text="help!"), + r"Invalid response from server \(503\)", + id="really-bad", + ), + pytest.param( + httpx.Response( + 503, + json={"error-list": [{"code": "whelp", "message": "we done goofed"}]}, + ), + r"Store had an error \(503\): we done goofed", + id="server-error", + ), + pytest.param( + httpx.Response( + 400, + json={"error-list": [{"code": "whelp", "message": "you messed up"}]}, + ), + r"Error 400 returned from store: you messed up", + id="client-error", + ), + pytest.param( + httpx.Response( + 418, + json={ + "error-list": [ + {"code": "whelp", "message": "I am a teapot"}, + { + "code": "bad", + "message": "Why would you ask me for a coffee?", + }, + ] + }, + ), + r"Error 418 returned from store. See log for details", + id="multiple-client-errors", + ), + ], +) +def test_check_error(response: httpx.Response, match): + with pytest.raises(errors.CraftStoreError, match=match): + publishergateway.PublisherGateway._check_error(response) + + +def test_get_package_metadata( + mock_httpx_client: mock.Mock, publisher_gateway: publishergateway.PublisherGateway +): + mock_httpx_client.get.return_value = httpx.Response( + 200, json={"metadata": {"meta": "data"}} + ) + + assert publisher_gateway.get_package_metadata("my-package") == {"meta": "data"} + + mock_httpx_client.get.assert_called_once_with(url="/v1/charm/my-package") + + +@pytest.mark.parametrize( + ("tracks", "match"), + [ + ([{"name": "-"}], ": -$"), + ( + [{"name": "123456789012345678901234567890"}], + ": 123456789012345678901234567890$", + ), + ([{"name": "-"}, {"name": "_!"}], ": -, _!$"), + ], +) +def test_create_tracks_validation( + publisher_gateway: publishergateway.PublisherGateway, + tracks, + match, +): + with pytest.raises(ValueError, match=match): + publisher_gateway.create_tracks("my-name", *tracks) + + +def test_create_tracks_success( + mock_httpx_client: mock.Mock, publisher_gateway: publishergateway.PublisherGateway +): + mock_httpx_client.post.return_value = httpx.Response( + 200, json={"num-tracks-created": 0} + ) + + assert publisher_gateway.create_tracks("my-name") == 0