Skip to content

Commit

Permalink
fix: pr suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
lengau committed Dec 12, 2024
1 parent 83764a9 commit 771e63c
Show file tree
Hide file tree
Showing 14 changed files with 62 additions and 35 deletions.
5 changes: 3 additions & 2 deletions craft_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
__version__ = "3.0.0"


from . import creds, endpoints, errors, models, publishergateway
from . import creds, endpoints, errors, models
from ._httpx_auth import CandidAuth, DeveloperTokenAuth
from .publisher import PublisherGateway
from .auth import Auth
from .base_client import BaseClient
from .http_client import HTTPClient
Expand All @@ -33,7 +34,7 @@
"endpoints",
"errors",
"models",
"publishergateway",
"PublisherGateway",
"Auth",
"BaseClient",
"CandidAuth",
Expand Down
9 changes: 6 additions & 3 deletions craft_store/_httpx_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def _update_headers(self, request: httpx.Request) -> None:
"""Add token to the request."""
logger.debug("Adding ephemeral token to request headers")
if self._token is None:
raise errors.DeveloperTokenUnavailableError(
message="Token is not available"
)
raise errors.AuthTokenUnavailableError(message="Token is not available")
request.headers["Authorization"] = self._format_auth_header()

def _format_auth_header(self) -> str:
Expand All @@ -73,6 +71,11 @@ def _format_auth_header(self) -> str:
class CandidAuth(_TokenAuth):
"""Candid based authentication class for httpx store clients."""

def __init__(
self, *, auth: auth.Auth, auth_type: Literal["bearer", "macaroon"] = "macaroon"
) -> None:
super().__init__(auth=auth, auth_type=auth_type)

def get_token_from_keyring(self) -> str:
"""Get token stored in the credentials storage."""
logger.debug("Getting candid from credential storage")
Expand Down
18 changes: 15 additions & 3 deletions craft_store/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,25 @@ def __init__(
store_errors: StoreErrorList | None = None,
) -> None:
super().__init__(message)
if not details:
if store_errors and not details:
details = str(store_errors)
self.details = details
self.resolution = resolution
self.store_errors = store_errors


class InvalidRequestError(CraftStoreError, ValueError):
"""Error when the request is invalid in a known way."""

def __init__(
self,
message: str,
details: str | None = None,
resolution: str | None = None,
) -> None:
super().__init__(message, details, resolution)


class NetworkError(CraftStoreError):
"""Error to raise on network or infrastructure issues.
Expand Down Expand Up @@ -209,5 +221,5 @@ def __init__(self, url: str) -> None:
super().__init__(f"Empty token value returned from {url!r}.")


class DeveloperTokenUnavailableError(CraftStoreError):
"""Raised when developer token is not set."""
class AuthTokenUnavailableError(CraftStoreError):
"""Raised when an authorization token is not available."""
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Client for the publisher gateway."""
from __future__ import annotations

from json import JSONDecodeError
from typing import cast

import httpx
Expand Down Expand Up @@ -48,7 +50,7 @@ def _check_error(response: httpx.Response) -> None:
return
try:
error_response = response.json()
except Exception as exc:
except JSONDecodeError as exc:
raise errors.CraftStoreError(
f"Invalid response from server ({response.status_code})",
details=response.text,
Expand All @@ -61,7 +63,8 @@ def _check_error(response: httpx.Response) -> None:
if len(error_list) == 1:
brief = f"{brief}: {error_list[0].get('message')}"
else:
brief = f"{brief}. See log for details"
fancy_error_list = errors.StoreErrorList(error_list)
brief = f"{brief}.\n{fancy_error_list}"
raise errors.CraftStoreError(
brief, store_errors=errors.StoreErrorList(error_list)
)
Expand All @@ -87,7 +90,7 @@ def create_tracks(self, name: str, *tracks: _request.CreateTrackRequest) -> int:
to which this track will be attached.
:param tracks: Each track is a dictionary mapping query values.
:returns: The number of tracks created by the store.
:raises: ValueError if a track name is invalid.
:returns: InvalidRequestError if the name field of any passed track is invalid.
API docs: https://api.charmhub.io/docs/default.html#create_tracks
"""
Expand All @@ -99,7 +102,10 @@ def create_tracks(self, name: str, *tracks: _request.CreateTrackRequest) -> int:
}
if bad_track_names:
bad_tracks = ", ".join(sorted(bad_track_names))
raise ValueError(f"The following track names are invalid: {bad_tracks}")
raise errors.InvalidRequestError(
f"The following track names are invalid: {bad_tracks}",
resolution="Ensure all tracks have valid names.",
)

response = self._client.post(
f"/v1/{self._namespace}/{name}/tracks", json=tracks
Expand Down
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest
import yaml
from craft_store import StoreClient, auth, endpoints, publishergateway
from craft_store import StoreClient, auth, endpoints, publisher


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -63,7 +63,7 @@ def charmhub_auth(charmhub_base_url):

@pytest.fixture
def publisher_gateway(charmhub_base_url, charmhub_auth):
return publishergateway.PublisherGateway(
return publisher.PublisherGateway(
base_url=charmhub_base_url, namespace="charm", auth=charmhub_auth
)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
"""Tests that only involve reading from the store."""


from craft_store import publishergateway
from craft_store import publisher

from tests.integration.conftest import needs_charmhub_credentials


@needs_charmhub_credentials()
def test_get_package_metadata(
publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str
publisher_gateway: publisher.PublisherGateway, charmhub_charm_name: str
):
metadata = publisher_gateway.get_package_metadata(charmhub_charm_name)
assert metadata.get("name") == charmhub_charm_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time

import pytest
from craft_store import errors, publishergateway
from craft_store import errors, publisher

from tests.integration.conftest import needs_charmhub_credentials

Expand All @@ -30,7 +30,7 @@
@pytest.mark.parametrize("version_pattern", [None, r"\d+"])
@pytest.mark.parametrize("percentages", [None, 50])
def test_create_tracks(
publisher_gateway: publishergateway.PublisherGateway,
publisher_gateway: publisher.PublisherGateway,
charmhub_charm_name: str,
version_pattern,
percentages,
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_create_tracks(
@pytest.mark.slow
@needs_charmhub_credentials()
def test_create_disallowed_track(
publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str
publisher_gateway: publisher.PublisherGateway, charmhub_charm_name: str
):
track_name = "disallowed"

Expand All @@ -78,7 +78,7 @@ def test_create_disallowed_track(
@pytest.mark.slow
@needs_charmhub_credentials()
def test_create_existing_track(
publisher_gateway: publishergateway.PublisherGateway, charmhub_charm_name: str
publisher_gateway: publisher.PublisherGateway, charmhub_charm_name: str
):
track_name = "1"

Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Unit tests for the publisher gateway."""

import textwrap
from unittest import mock

import httpx
import pytest
from craft_store import errors, publishergateway
from craft_store import errors, publisher


@pytest.fixture
Expand All @@ -29,14 +30,14 @@ def mock_httpx_client():

@pytest.fixture
def publisher_gateway(mock_httpx_client):
gw = publishergateway.PublisherGateway("http://localhost", "charm", mock.Mock())
gw = publisher.PublisherGateway("http://localhost", "charm", mock.Mock())
gw._client = mock_httpx_client
return gw


@pytest.mark.parametrize("response", [httpx.Response(status_code=204)])
def test_check_error_on_success(response: httpx.Response):
assert publishergateway.PublisherGateway._check_error(response) is None
assert publisher.PublisherGateway._check_error(response) is None


@pytest.mark.parametrize(
Expand Down Expand Up @@ -68,26 +69,31 @@ def test_check_error_on_success(response: httpx.Response):
418,
json={
"error-list": [
{"code": "whelp", "message": "I am a teapot"},
{"code": "good", "message": "I am a teapot"},
{
"code": "bad",
"message": "Why would you ask me for a coffee?",
"message": "Why would you ask me for coffee?",
},
]
},
),
r"Error 418 returned from store. See log for details",
textwrap.dedent(
"""\
Error 418 returned from store.
- good: I am a teapot
- bad: Why would you ask me for coffee?"""
),
id="multiple-client-errors",
),
],
)
def test_check_error(response: httpx.Response, match):
with pytest.raises(errors.CraftStoreError, match=match):
publishergateway.PublisherGateway._check_error(response)
publisher.PublisherGateway._check_error(response)


def test_get_package_metadata(
mock_httpx_client: mock.Mock, publisher_gateway: publishergateway.PublisherGateway
mock_httpx_client: mock.Mock, publisher_gateway: publisher.PublisherGateway
):
mock_httpx_client.get.return_value = httpx.Response(
200, json={"metadata": {"meta": "data"}}
Expand All @@ -110,7 +116,7 @@ def test_get_package_metadata(
],
)
def test_create_tracks_validation(
publisher_gateway: publishergateway.PublisherGateway,
publisher_gateway: publisher.PublisherGateway,
tracks,
match,
):
Expand All @@ -119,7 +125,7 @@ def test_create_tracks_validation(


def test_create_tracks_success(
mock_httpx_client: mock.Mock, publisher_gateway: publishergateway.PublisherGateway
mock_httpx_client: mock.Mock, publisher_gateway: publisher.PublisherGateway
):
mock_httpx_client.post.return_value = httpx.Response(
200, json={"num-tracks-created": 0}
Expand Down
9 changes: 4 additions & 5 deletions tests/unit/test_httpx_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_candid_auth_flow(mock_auth, candid_auth):

next(candid_auth.auth_flow(request))

assert request.headers["Authorization"] == "Bearer {}"
assert request.headers["Authorization"] == "Macaroon {}"


@pytest.fixture
Expand Down Expand Up @@ -122,10 +122,9 @@ def test_auth_if_token_unset(
mocker.patch.object(
developer_token_auth, "get_token_from_keyring", return_value=None
)
httpx_client = httpx.Client(auth=developer_token_auth)

client = httpx.Client(auth=developer_token_auth)
with pytest.raises(
errors.DeveloperTokenUnavailableError,
errors.AuthTokenUnavailableError,
match="Token is not available",
):
httpx_client.request("GET", "https://fake-testcraft-url.localhost")
client.request("GET", "https://fake-testcraft-url.localhost")

0 comments on commit 771e63c

Please sign in to comment.