From aa30e4ffc92edaa498231ae33eba74dd2a46eb49 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 14:33:05 +0200 Subject: [PATCH 01/19] added optional event params for relation events --- scenario/runtime.py | 31 ++++++++++++- scenario/state.py | 47 ++++++++++++++++---- tests/test_e2e/test_relations.py | 76 +++++++++++++++++++++++++++++++- tox.ini | 2 + 4 files changed, 145 insertions(+), 11 deletions(-) diff --git a/scenario/runtime.py b/scenario/runtime.py index 35757f3b..b55fc83c 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -157,7 +157,8 @@ def _cleanup_env(env): # running this in a clean venv or a container anyway. # cleanup env, in case we'll be firing multiple events, we don't want to accumulate. for key in env: - os.unsetenv(key) + # os.unsetenv does not work !? + del os.environ[key] def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event.name.endswith("_action"): @@ -183,9 +184,34 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): { "JUJU_RELATION": relation.endpoint, "JUJU_RELATION_ID": str(relation.relation_id), + "JUJU_REMOTE_APP": relation.remote_app_name, } ) + if event._is_relation_event: # noqa + remote_unit_id = event.relation_remote_unit_id + if not remote_unit_id: + if len(relation.remote_unit_ids) == 1: + remote_unit_id = relation.remote_unit_ids[0] + logger.info( + "there's only one remote unit, so we set JUJU_REMOTE_UNIT to it, " + "but you probably should be parametrizing the event with `remote_unit` " + "to be explicit." + ) + else: + logger.warning( + "unable to determine remote unit ID; which means JUJU_REMOTE_UNIT will " + "be unset and you might get error if charm code attempts to access " + "`event.unit` in event handlers. \n" + "If that is the case, pass `remote_unit` to the Event constructor." + ) + + if remote_unit_id: + remote_unit = f"{relation.remote_app_name}/{remote_unit_id}" + env["JUJU_REMOTE_UNIT"] = remote_unit + if event.name.endswith("_relation_departed"): + env["JUJU_DEPARTING_UNIT"] = remote_unit + if container := event.container: env.update({"JUJU_WORKLOAD_NAME": container.name}) @@ -348,8 +374,9 @@ def exec( finally: logger.info(" - Exited ops.main.") - logger.info(" - clearing env") + logger.info(" - Clearing env") self._cleanup_env(env) + assert not os.getenv("JUJU_DEPARTING_UNIT") logger.info(" - closing storage") output_state = self._close_storage(output_state, temporary_charm_root) diff --git a/scenario/state.py b/scenario/state.py index 50160cb1..1d218aed 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -144,6 +144,27 @@ def normalize_name(s: str): return s.replace("-", "_") +class ParametrizedEvent: + def __init__(self, accept_params: Tuple[str], category: str, *args, **kwargs): + self._accept_params = accept_params + self._category = category + self._args = args + self._kwargs = kwargs + + def __call__(self, remote_unit: Optional[str] = None) -> "Event": + """Construct an Event object using the arguments provided at init and any extra params.""" + if remote_unit and "remote_unit" not in self._accept_params: + raise ValueError( + f"cannot pass param `remote_unit` to a " + f"{self._category} event constructor." + ) + + return Event(*self._args, *self._kwargs, relation_remote_unit=remote_unit) + + def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": + return self().deferred(handler=handler, event_id=event_id) + + @dataclasses.dataclass class Relation(_DCBase): endpoint: str @@ -191,35 +212,35 @@ def __post_init__(self): self.remote_units_data = {0: {}} @property - def changed_event(self): + def changed_event(self) -> "Event": """Sugar to generate a -relation-changed event.""" return Event( name=normalize_name(self.endpoint + "-relation-changed"), relation=self ) @property - def joined_event(self): + def joined_event(self) -> "Event": """Sugar to generate a -relation-joined event.""" return Event( name=normalize_name(self.endpoint + "-relation-joined"), relation=self ) @property - def created_event(self): + def created_event(self) -> "Event": """Sugar to generate a -relation-created event.""" return Event( name=normalize_name(self.endpoint + "-relation-created"), relation=self ) @property - def departed_event(self): + def departed_event(self) -> "Event": """Sugar to generate a -relation-departed event.""" return Event( name=normalize_name(self.endpoint + "-relation-departed"), relation=self ) @property - def broken_event(self): + def broken_event(self) -> "Event": """Sugar to generate a -relation-broken event.""" return Event( name=normalize_name(self.endpoint + "-relation-broken"), relation=self @@ -721,6 +742,8 @@ class Event(_DCBase): # if this is a relation event, the relation it refers to relation: Optional[Relation] = None + # and the name of the remote unit this relation event is about + relation_remote_unit_id: Optional[int] = None # if this is a secret event, the secret it refers to secret: Optional[Secret] = None @@ -733,6 +756,14 @@ class Event(_DCBase): # - pebble? # - action? + def __call__(self, remote_unit: Optional[int] = None) -> "Event": + if remote_unit and not self._is_relation_event: + raise ValueError( + "cannot pass param `remote_unit` to a " + "non-relation event constructor." + ) + return self.replace(relation_remote_unit_id=remote_unit) + def __post_init__(self): if "-" in self.name: logger.warning(f"Only use underscores in event names. {self.name!r}") @@ -834,9 +865,9 @@ def deferred(self, handler: Callable, event_id: int = 1) -> DeferredEvent: # this is a RelationEvent. The snapshot: snapshot_data = { "relation_name": self.relation.endpoint, - "relation_id": self.relation.relation_id - # 'app_name': local app name - # 'unit_name': local unit name + "relation_id": self.relation.relation_id, + "app_name": self.relation.remote_app_name, + "unit_name": f"{self.relation.remote_app_name}/{self.relation_remote_unit_id}", } return DeferredEvent( diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 58dd3205..9bbc8502 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -1,7 +1,8 @@ +import os from typing import Type import pytest -from ops.charm import CharmBase, CharmEvents +from ops.charm import CharmBase, CharmEvents, RelationDepartedEvent from ops.framework import EventBase, Framework from scenario.state import Relation, State @@ -124,3 +125,76 @@ def callback(charm: CharmBase, _): }, }, ) + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "remote_app_name", + ("remote", "prometheus", "aodeok123"), +) +def test_relation_events_attrs(mycharm, evt_name, remote_app_name): + relation = Relation( + endpoint="foo", interface="foo", remote_app_name=remote_app_name + ) + + def callback(charm: CharmBase, event): + assert event.app + assert event.unit + if isinstance(event, RelationDepartedEvent): + assert event.departing_unit + + mycharm._call = callback + + State( + relations=[ + relation, + ], + ).trigger( + getattr(relation, f"{evt_name}_event")(remote_unit=1), + mycharm, + meta={ + "name": "local", + "requires": { + "foo": {"interface": "foo"}, + }, + }, + ) + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "remote_app_name", + ("remote", "prometheus", "aodeok123"), +) +def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name): + relation = Relation( + endpoint="foo", interface="foo", remote_app_name=remote_app_name + ) + + def callback(charm: CharmBase, event): + assert event.app # that's always present + assert not event.unit + assert not getattr(event, "departing_unit", False) + + mycharm._call = callback + + State( + relations=[ + relation, + ], + ).trigger( + getattr(relation, f"{evt_name}_event"), + mycharm, + meta={ + "name": "local", + "requires": { + "foo": {"interface": "foo"}, + }, + }, + ) diff --git a/tox.ini b/tox.ini index 8689e828..a25b8f0f 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ commands = [testenv:lint] +skip_install=True description = lint deps = coverage[toml] @@ -41,6 +42,7 @@ commands = [testenv:fmt] +skip_install=True description = Format code deps = black From e640ec9edbeb981b4016547a7466e43c26f3a69a Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 14:37:58 +0200 Subject: [PATCH 02/19] fixed truthiness --- scenario/runtime.py | 4 ++-- tests/test_e2e/test_relations.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/scenario/runtime.py b/scenario/runtime.py index b55fc83c..cbf89e83 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -190,7 +190,7 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event._is_relation_event: # noqa remote_unit_id = event.relation_remote_unit_id - if not remote_unit_id: + if remote_unit_id is None: # don't check truthiness because it could be int(0) if len(relation.remote_unit_ids) == 1: remote_unit_id = relation.remote_unit_ids[0] logger.info( @@ -206,7 +206,7 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): "If that is the case, pass `remote_unit` to the Event constructor." ) - if remote_unit_id: + if remote_unit_id is not None: remote_unit = f"{relation.remote_app_name}/{remote_unit_id}" env["JUJU_REMOTE_UNIT"] = remote_unit if event.name.endswith("_relation_departed"): diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 9bbc8502..4551de0c 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -135,7 +135,11 @@ def callback(charm: CharmBase, _): "remote_app_name", ("remote", "prometheus", "aodeok123"), ) -def test_relation_events_attrs(mycharm, evt_name, remote_app_name): +@pytest.mark.parametrize( + "remote_unit_id", + (0, 1), +) +def test_relation_events_attrs(mycharm, evt_name, remote_app_name, remote_unit_id): relation = Relation( endpoint="foo", interface="foo", remote_app_name=remote_app_name ) @@ -153,7 +157,7 @@ def callback(charm: CharmBase, event): relation, ], ).trigger( - getattr(relation, f"{evt_name}_event")(remote_unit=1), + getattr(relation, f"{evt_name}_event")(remote_unit=remote_unit_id), mycharm, meta={ "name": "local", From 33675481b200e7c2b02a798679cf128ab1a4f9a9 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 14:40:59 +0200 Subject: [PATCH 03/19] vbump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f93c900e..f8936873 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "ops-scenario" -version = "2.1.3.3" +version = "2.1.3.4" authors = [ { name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" } ] From 339d332b41eeabe71a08e1e638c034a646de1a11 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 14:49:59 +0200 Subject: [PATCH 04/19] readme --- README.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/README.md b/README.md index 62dbcf0f..0ab90774 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,40 @@ def test_relation_data(): # which is very idiomatic and superbly explicit. Noice. ``` +If you want to trigger relation events, the easiest way to do so is get a hold of the Relation instance and grab the event from one of its aptly-named properties: + +```python +from scenario import Relation +relation = Relation(endpoint="foo", interface="bar") +changed_event = relation.changed_event +joined_event = relation.joined_event +# ... +``` + +This is in fact syntactic sugar for: +```python +from scenario import Relation, Event +relation = Relation(endpoint="foo", interface="bar") +changed_event = Event('foo-relation-changed', relation=relation) +``` + +The reason for this construction is that the event is associated with some relation-specific metadata, that Scenario needs to set up the process that will run `ops.main` with the right environment variables. + +### Additional event parameters +All relation events have some additional metadata that does not belong in the Relation object, such as, for a relation-joined event, the name of the (remote) unit that is joining the relation. That is what determines what `ops.model.Unit` you get when you get `RelationJoinedEvent().unit` in an event handler. + +In order to supply this parameter, you will have to **call** the event object and pass as `remote_unit` the id of the remote unit that the event is about. + +```python +from scenario import Relation, Event +relation = Relation(endpoint="foo", interface="bar") +remote_unit_2_is_joining_event = relation.joined_event(remote_unit=2) + +# which is syntactic sugar for: +remote_unit_2_is_joining_event = Event('foo-relation-changed', relation=relation, relation_remote_unit_id=2) +``` + + ## Containers When testing a kubernetes charm, you can mock container interactions. From 3e86fb4931ba6c5ca5f1f97492addd9d2cf70720 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 15:06:34 +0200 Subject: [PATCH 05/19] lint --- scenario/runtime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scenario/runtime.py b/scenario/runtime.py index cbf89e83..b8668b59 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -190,7 +190,9 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event._is_relation_event: # noqa remote_unit_id = event.relation_remote_unit_id - if remote_unit_id is None: # don't check truthiness because it could be int(0) + if ( + remote_unit_id is None + ): # don't check truthiness because it could be int(0) if len(relation.remote_unit_ids) == 1: remote_unit_id = relation.remote_unit_ids[0] logger.info( From e9b0b3e254d89c56d3c547511bda36a9e6740dfc Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 29 Mar 2023 15:35:59 +0200 Subject: [PATCH 06/19] utest fix --- tests/test_e2e/test_relations.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 4551de0c..38663e42 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -176,9 +176,12 @@ def callback(charm: CharmBase, event): "remote_app_name", ("remote", "prometheus", "aodeok123"), ) -def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name): +def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name, caplog): relation = Relation( - endpoint="foo", interface="foo", remote_app_name=remote_app_name + endpoint="foo", + interface="foo", + remote_app_name=remote_app_name, + remote_units_data={0: {}, 1: {}}, # 2 units ) def callback(charm: CharmBase, event): @@ -202,3 +205,5 @@ def callback(charm: CharmBase, event): }, }, ) + + assert "unable to determine remote unit ID" in caplog.text From 6d8cd70367df8d19ecfcb65cd39e5da2c19a33eb Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Thu, 30 Mar 2023 12:31:44 +0200 Subject: [PATCH 07/19] added databag validators --- scenario/state.py | 26 +++++++++++++++++++++++++- tests/test_e2e/test_relations.py | 17 ++++++++++++++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/scenario/state.py b/scenario/state.py index 1d218aed..bab77dda 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -63,6 +63,13 @@ } +class StateValidationError(RuntimeError): + """Raised when individual parts of the State are inconsistent.""" + + # as opposed to InconsistentScenario error where the + # **combination** of several parts of the State are. + + @dataclasses.dataclass class _DCBase: def replace(self, *args, **kwargs): @@ -200,7 +207,7 @@ def __post_init__(self): if self.remote_unit_ids and self.remote_units_data: if not set(self.remote_unit_ids) == set(self.remote_units_data): - raise ValueError( + raise StateValidationError( f"{self.remote_unit_ids} should include any and all IDs from {self.remote_units_data}" ) elif self.remote_unit_ids: @@ -211,6 +218,23 @@ def __post_init__(self): self.remote_unit_ids = [0] self.remote_units_data = {0: {}} + for databag in ( + self.local_unit_data, + self.local_app_data, + self.remote_app_data, + *self.remote_units_data.values(), + ): + if not isinstance(databag, dict): + raise StateValidationError( + f"all databags should be dicts, not {type(databag)}" + ) + for k, v in databag.items(): + if not isinstance(v, str): + raise StateValidationError( + f"all databags should be Dict[str,str]; " + f"found a value of type {type(v)}" + ) + @property def changed_event(self) -> "Event": """Sugar to generate a -relation-changed event.""" diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 38663e42..41541e34 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -5,7 +5,8 @@ from ops.charm import CharmBase, CharmEvents, RelationDepartedEvent from ops.framework import EventBase, Framework -from scenario.state import Relation, State +from scenario.runtime import InconsistentScenarioError +from scenario.state import Relation, State, StateValidationError @pytest.fixture(scope="function") @@ -207,3 +208,17 @@ def callback(charm: CharmBase, event): ) assert "unable to determine remote unit ID" in caplog.text + + +@pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) +def test_relation_unit_data_bad_types(mycharm, data): + with pytest.raises(StateValidationError): + relation = Relation( + endpoint="foo", interface="foo", remote_units_data={0: {"a": data}} + ) + + +@pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) +def test_relation_app_data_bad_types(mycharm, data): + with pytest.raises(StateValidationError): + relation = Relation(endpoint="foo", interface="foo", local_app_data={"a": data}) From 7d6e5a24062cf3db59327680d1239cbd6254c55c Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Thu, 30 Mar 2023 14:42:18 +0200 Subject: [PATCH 08/19] added sub and peer relation types --- scenario/consistency_checker.py | 16 ++- scenario/mocking.py | 34 +++++- scenario/state.py | 185 ++++++++++++++++++++++++------- tests/test_e2e/test_relations.py | 21 +++- 4 files changed, 206 insertions(+), 50 deletions(-) diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 0c1b6759..520fcf43 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -3,7 +3,7 @@ from scenario.runtime import InconsistentScenarioError from scenario.runtime import logger as scenario_logger -from scenario.state import _CharmSpec, normalize_name +from scenario.state import SubordinateRelation, _CharmSpec, normalize_name if TYPE_CHECKING: from scenario.state import Event, State @@ -51,6 +51,7 @@ def check_consistency( check_config_consistency, check_event_consistency, check_secrets_consistency, + check_relation_consistency, ): results = check( state=state, event=event, charm_spec=charm_spec, juju_version=juju_version @@ -179,6 +180,19 @@ def check_secrets_consistency( return Results(errors, []) +def check_relation_consistency( + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs +) -> Results: + errors = [] + for relation in state.relations: + if isinstance(relation, SubordinateRelation): + # todo: verify that this unit's id is not in: + # relation.remote_unit_id + pass + + return Results(errors, []) + + def check_containers_consistency( *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: diff --git a/scenario/mocking.py b/scenario/mocking.py index 80069804..04f0e9b0 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -16,7 +16,15 @@ if TYPE_CHECKING: from scenario.state import Container as ContainerSpec - from scenario.state import Event, ExecOutput, State, _CharmSpec + from scenario.state import ( + Event, + ExecOutput, + PeerRelation, + Relation, + State, + SubordinateRelation, + _CharmSpec, + ) logger = scenario_logger.getChild("mocking") @@ -62,7 +70,9 @@ def get_pebble(self, socket_path: str) -> "Client": charm_spec=self._charm_spec, ) - def _get_relation_by_id(self, rel_id): + def _get_relation_by_id( + self, rel_id + ) -> Union["Relation", "SubordinateRelation", "PeerRelation"]: try: return next( filter(lambda r: r.relation_id == rel_id, self._state.relations) @@ -121,10 +131,22 @@ def relation_ids(self, relation_name): def relation_list(self, relation_id: int): relation = self._get_relation_by_id(relation_id) - return tuple( - f"{relation.remote_app_name}/{unit_id}" - for unit_id in relation.remote_unit_ids - ) + relation_type = getattr(relation, "__type__", "") + if relation_type == "regular": + return tuple( + f"{relation.remote_app_name}/{unit_id}" + for unit_id in relation.remote_unit_ids + ) + elif relation_type == "peer": + return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) + + elif relation_type == "subordinate": + return tuple(f"{relation.primary_name}") + else: + raise RuntimeError( + f"Invalid relation type: {relation_type}; should be one of " + f"scenario.state.RelationType" + ) def config_get(self): state_config = self._state.config diff --git a/scenario/state.py b/scenario/state.py index bab77dda..20b384ec 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -7,6 +7,7 @@ import inspect import re import typing +from enum import Enum from itertools import chain from pathlib import Path, PurePosixPath from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union @@ -172,19 +173,18 @@ def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": return self().deferred(handler=handler, event_id=event_id) -@dataclasses.dataclass -class Relation(_DCBase): - endpoint: str - remote_app_name: str = "remote" - remote_unit_ids: List[int] = dataclasses.field(default_factory=list) +class RelationType(str, Enum): + subordinate = "subordinate" + regular = "regular" + peer = "peer" - # local limit - limit: int = 1 - # scale of the remote application; number of units, leader ID? - # TODO figure out if this is relevant - scale: int = 1 - leader_id: int = 0 +@dataclasses.dataclass +class RelationBase(_DCBase): + if typing.TYPE_CHECKING: + __type__: RelationType + + endpoint: str # we can derive this from the charm's metadata interface: str = None @@ -193,11 +193,12 @@ class Relation(_DCBase): relation_id: int = -1 local_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) - remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) local_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) - remote_units_data: Dict[int, Dict[str, str]] = dataclasses.field( - default_factory=dict - ) + + @property + def __databags__(self): + yield self.local_app_data + yield self.local_unit_data def __post_init__(self): global _RELATION_IDS_CTR @@ -205,35 +206,20 @@ def __post_init__(self): _RELATION_IDS_CTR += 1 self.relation_id = _RELATION_IDS_CTR - if self.remote_unit_ids and self.remote_units_data: - if not set(self.remote_unit_ids) == set(self.remote_units_data): - raise StateValidationError( - f"{self.remote_unit_ids} should include any and all IDs from {self.remote_units_data}" - ) - elif self.remote_unit_ids: - self.remote_units_data = {x: {} for x in self.remote_unit_ids} - elif self.remote_units_data: - self.remote_unit_ids = [x for x in self.remote_units_data] - else: - self.remote_unit_ids = [0] - self.remote_units_data = {0: {}} - - for databag in ( - self.local_unit_data, - self.local_app_data, - self.remote_app_data, - *self.remote_units_data.values(), - ): - if not isinstance(databag, dict): + for databag in self.__databags__: + self._validate_databag(databag) + + def _validate_databag(self, databag: dict): + if not isinstance(databag, dict): + raise StateValidationError( + f"all databags should be dicts, not {type(databag)}" + ) + for k, v in databag.items(): + if not isinstance(v, str): raise StateValidationError( - f"all databags should be dicts, not {type(databag)}" + f"all databags should be Dict[str,str]; " + f"found a value of type {type(v)}" ) - for k, v in databag.items(): - if not isinstance(v, str): - raise StateValidationError( - f"all databags should be Dict[str,str]; " - f"found a value of type {type(v)}" - ) @property def changed_event(self) -> "Event": @@ -271,6 +257,121 @@ def broken_event(self) -> "Event": ) +def unify_ids_and_remote_units_data(ids: List[int], data: Dict[int, Any]): + """Unify and validate a list of unit IDs and a mapping from said ids to databag contents. + + This allows the user to pass equivalently: + ids = [] + data = {1: {}} + + or + + ids = [1] + data = {} + + or + + ids = [1] + data = {1: {}} + + but catch the inconsistent: + + ids = [1] + data = {2: {}} + + or + + ids = [2] + data = {1: {}} + """ + if ids and data: + if not set(ids) == set(data): + raise StateValidationError( + f"{ids} should include any and all IDs from {data}" + ) + elif ids: + data = {x: {} for x in ids} + elif data: + ids = [x for x in data] + else: + ids = [0] + data = {0: {}} + return ids, data + + +@dataclasses.dataclass +class Relation(RelationBase): + __type__ = RelationType.regular + remote_app_name: str = "remote" + remote_unit_ids: List[int] = dataclasses.field(default_factory=list) + + # local limit + limit: int = 1 + + remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) + remote_units_data: Dict[int, Dict[str, str]] = dataclasses.field( + default_factory=dict + ) + + @property + def __databags__(self): + yield self.local_app_data + yield self.local_unit_data + yield self.remote_app_data + yield from self.remote_units_data.values() + + def __post_init__(self): + super().__post_init__() + self.remote_unit_ids, self.remote_units_data = unify_ids_and_remote_units_data( + self.remote_unit_ids, self.remote_units_data + ) + + +@dataclasses.dataclass +class SubordinateRelation(RelationBase): + __type__ = RelationType.subordinate + remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) + remote_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) + + # app name and ID of the primary that *this unit* is attached to. + primary_app_name: str = "remote" + primary_id: int = 0 + + # IDs of the peers. Consistency checks will validate that *this unit*'s ID is not in here. + peers_ids: List[int] = dataclasses.field(default_factory=list) + + @property + def __databags__(self): + yield self.local_app_data + yield self.local_unit_data + yield self.remote_app_data + yield self.remote_unit_data + + @property + def primary_name(self) -> str: + return f"{self.primary_app_name}/{self.primary_id}" + + +@dataclasses.dataclass +class PeerRelation(RelationBase): + __type__ = RelationType.peer + peers_data: Dict[int, Dict[str, str]] = dataclasses.field(default_factory=dict) + + # IDs of the peers. Consistency checks will validate that *this unit*'s ID is not in here. + peers_ids: List[int] = dataclasses.field(default_factory=list) + + @property + def __databags__(self): + yield self.local_app_data + yield self.local_unit_data + yield from self.peers_data.values() + + def __post_init__(self): + self.peers_ids, self.peers_data = unify_ids_and_remote_units_data( + self.peers_ids, self.peers_data + ) + + def _random_model_name(): import random import string diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 41541e34..1d659703 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -6,7 +6,14 @@ from ops.framework import EventBase, Framework from scenario.runtime import InconsistentScenarioError -from scenario.state import Relation, State, StateValidationError +from scenario.state import ( + PeerRelation, + Relation, + RelationType, + State, + StateValidationError, + SubordinateRelation, +) @pytest.fixture(scope="function") @@ -222,3 +229,15 @@ def test_relation_unit_data_bad_types(mycharm, data): def test_relation_app_data_bad_types(mycharm, data): with pytest.raises(StateValidationError): relation = Relation(endpoint="foo", interface="foo", local_app_data={"a": data}) + + +@pytest.mark.parametrize( + "relation, expected_type", + ( + (Relation("a"), RelationType.regular), + (PeerRelation("b"), RelationType.peer), + (SubordinateRelation("b"), RelationType.subordinate), + ), +) +def test_relation_type(relation, expected_type): + assert relation.__type__ == expected_type From 80ba34386687265326b02db6eda7f54159710243 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Thu, 30 Mar 2023 15:03:31 +0200 Subject: [PATCH 09/19] fixed emission for peer/sub --- scenario/mocking.py | 1 + scenario/runtime.py | 88 ++++++++++++++++++++++---------- scenario/state.py | 9 ++-- tests/test_e2e/test_relations.py | 26 ++++++++++ 4 files changed, 92 insertions(+), 32 deletions(-) diff --git a/scenario/mocking.py b/scenario/mocking.py index 04f0e9b0..db9370b2 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -104,6 +104,7 @@ def _generate_secret_id(): return f"secret:{id}" def relation_get(self, rel_id, obj_name, app): + # fixme: this WILL definitely bork with peer and sub relation types. relation = self._get_relation_by_id(rel_id) if app and obj_name == self.app_name: return relation.local_app_data diff --git a/scenario/runtime.py b/scenario/runtime.py index b8668b59..964e216e 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -29,7 +29,14 @@ if TYPE_CHECKING: from ops.testing import CharmType - from scenario.state import DeferredEvent, Event, State, StoredState, _CharmSpec + from scenario.state import ( + AnyRelation, + DeferredEvent, + Event, + State, + StoredState, + _CharmSpec, + ) _CT = TypeVar("_CT", bound=Type[CharmType]) @@ -148,6 +155,7 @@ def __init__( if not app_name: raise ValueError('invalid metadata: mandatory "name" field is missing.') + self._app_name = app_name # todo: consider parametrizing unit-id? cfr https://github.com/canonical/ops-scenario/issues/11 self._unit_name = f"{app_name}/0" @@ -179,40 +187,64 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): # todo consider setting pwd, (python)path } - if relation := event.relation: + relation: "AnyRelation" + from scenario.state import RelationType # avoid cyclic import # todo refactor + + if event._is_relation_event and (relation := event.relation): # noqa + if relation.__type__ == RelationType.regular: + remote_app_name = relation.remote_app_name + elif relation.__type__ == RelationType.peer: + remote_app_name = self._app_name + elif relation.__type__ == RelationType.subordinate: + remote_app_name = relation.primary_app_name + else: + raise TypeError( + f"Invalid relation type for {relation}: {relation.__type__}" + ) + env.update( { "JUJU_RELATION": relation.endpoint, "JUJU_RELATION_ID": str(relation.relation_id), - "JUJU_REMOTE_APP": relation.remote_app_name, + "JUJU_REMOTE_APP": remote_app_name, } ) - if event._is_relation_event: # noqa - remote_unit_id = event.relation_remote_unit_id - if ( - remote_unit_id is None - ): # don't check truthiness because it could be int(0) - if len(relation.remote_unit_ids) == 1: - remote_unit_id = relation.remote_unit_ids[0] - logger.info( - "there's only one remote unit, so we set JUJU_REMOTE_UNIT to it, " - "but you probably should be parametrizing the event with `remote_unit` " - "to be explicit." - ) - else: - logger.warning( - "unable to determine remote unit ID; which means JUJU_REMOTE_UNIT will " - "be unset and you might get error if charm code attempts to access " - "`event.unit` in event handlers. \n" - "If that is the case, pass `remote_unit` to the Event constructor." - ) - - if remote_unit_id is not None: - remote_unit = f"{relation.remote_app_name}/{remote_unit_id}" - env["JUJU_REMOTE_UNIT"] = remote_unit - if event.name.endswith("_relation_departed"): - env["JUJU_DEPARTING_UNIT"] = remote_unit + remote_unit_id = event.relation_remote_unit_id + if ( + remote_unit_id is None + ): # don't check truthiness because it could be int(0) + if relation.__type__ == RelationType.regular: + remote_unit_ids = relation.remote_unit_ids + elif relation.__type__ == RelationType.peer: + remote_unit_ids = relation.peers_ids + elif relation.__type__ == RelationType.subordinate: + remote_unit_ids = [relation.primary_id] + else: + raise TypeError( + f"Invalid relation type for {relation}: {relation.__type__}" + ) + + if len(remote_unit_ids) == 1: + remote_unit_id = remote_unit_ids[0] + logger.info( + "there's only one remote unit, so we set JUJU_REMOTE_UNIT to it, " + "but you probably should be parametrizing the event with `remote_unit` " + "to be explicit." + ) + else: + logger.warning( + "unable to determine remote unit ID; which means JUJU_REMOTE_UNIT will " + "be unset and you might get error if charm code attempts to access " + "`event.unit` in event handlers. \n" + "If that is the case, pass `remote_unit` to the Event constructor." + ) + + if remote_unit_id is not None: + remote_unit = f"{remote_app_name}/{remote_unit_id}" + env["JUJU_REMOTE_UNIT"] = remote_unit + if event.name.endswith("_relation_departed"): + env["JUJU_DEPARTING_UNIT"] = remote_unit if container := event.container: env.update({"JUJU_WORKLOAD_NAME": container.name}) diff --git a/scenario/state.py b/scenario/state.py index 20b384ec..59c1c792 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -30,6 +30,8 @@ from ops.testing import CharmType PathLike = Union[str, Path] + AnyRelation = Union["Relation", "PeerRelation", "SubordinateRelation"] + logger = scenario_logger.getChild("state") @@ -330,6 +332,8 @@ def __post_init__(self): @dataclasses.dataclass class SubordinateRelation(RelationBase): __type__ = RelationType.subordinate + + # todo: consider renaming them to primary_*_data remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) remote_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -337,9 +341,6 @@ class SubordinateRelation(RelationBase): primary_app_name: str = "remote" primary_id: int = 0 - # IDs of the peers. Consistency checks will validate that *this unit*'s ID is not in here. - peers_ids: List[int] = dataclasses.field(default_factory=list) - @property def __databags__(self): yield self.local_app_data @@ -708,7 +709,7 @@ class State(_DCBase): config: Dict[str, Union[str, int, float, bool]] = dataclasses.field( default_factory=dict ) - relations: List[Relation] = dataclasses.field(default_factory=list) + relations: List[RelationBase] = dataclasses.field(default_factory=list) networks: List[Network] = dataclasses.field(default_factory=list) containers: List[Container] = dataclasses.field(default_factory=list) status: Status = dataclasses.field(default_factory=Status) diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 1d659703..66633d2e 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -241,3 +241,29 @@ def test_relation_app_data_bad_types(mycharm, data): ) def test_relation_type(relation, expected_type): assert relation.__type__ == expected_type + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "relation", + (Relation("a"), PeerRelation("b"), SubordinateRelation("b")), +) +def test_relation_event_trigger(relation, evt_name, mycharm): + meta = { + "name": "mycharm", + "requires": {"a": {"interface": "i1"}}, + "provides": { + "c": { + "interface": "i3", + # this is a subordinate relation. + "scope": "container", + } + }, + "peers": {"b": {"interface": "i2"}}, + } + state = State(relations=[relation]).trigger( + getattr(relation, evt_name + "_event"), mycharm, meta=meta + ) From 007facdda55203285b15367c393ddb3e1ad35181 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Thu, 30 Mar 2023 15:06:06 +0200 Subject: [PATCH 10/19] fixed relation-list for subs --- scenario/mocking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scenario/mocking.py b/scenario/mocking.py index db9370b2..e2e27763 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -130,7 +130,7 @@ def relation_ids(self, relation_name): if rel.endpoint == relation_name ] - def relation_list(self, relation_id: int): + def relation_list(self, relation_id: int) -> Tuple[str]: relation = self._get_relation_by_id(relation_id) relation_type = getattr(relation, "__type__", "") if relation_type == "regular": @@ -142,7 +142,7 @@ def relation_list(self, relation_id: int): return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) elif relation_type == "subordinate": - return tuple(f"{relation.primary_name}") + return f"{relation.primary_name}", else: raise RuntimeError( f"Invalid relation type: {relation_type}; should be one of " From ca3fde95feed601cd8986975cfd217afce13a1f5 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Thu, 30 Mar 2023 15:37:57 +0200 Subject: [PATCH 11/19] more tests for subs --- scenario/consistency_checker.py | 36 +++++++++++++++++++++++++++----- scenario/mocking.py | 19 ++++++++++++++--- tests/test_e2e/test_relations.py | 30 ++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 520fcf43..2d9a3e8b 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -1,4 +1,6 @@ import os +from collections import Counter +from itertools import chain from typing import TYPE_CHECKING, Iterable, NamedTuple, Tuple from scenario.runtime import InconsistentScenarioError @@ -184,11 +186,35 @@ def check_relation_consistency( *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: errors = [] - for relation in state.relations: - if isinstance(relation, SubordinateRelation): - # todo: verify that this unit's id is not in: - # relation.remote_unit_id - pass + # check endpoint unicity + seen_endpoints = set() + for rel in chain( + charm_spec.meta.get("requires", ()), + charm_spec.meta.get("provides", ()), + charm_spec.meta.get("peers", ()), + ): + if rel in seen_endpoints: + errors.append("duplicate endpoint name in metadata.") + break + seen_endpoints.add(rel) + + subs = list(filter(lambda x: isinstance(x, SubordinateRelation), state.relations)) + + # check subordinate relation consistency + seen_sub_primaries = set() + sub: SubordinateRelation + for sub in subs: + sig = (sub.primary_name, sub.endpoint) + if sig in seen_sub_primaries: + errors.append( + "cannot have multiple subordinate relations on the same endpoint with the same primary." + ) + break + seen_sub_primaries.add(sig) + + for sub in subs: + # todo: verify that *this unit*'s id is not in {relation.remote_unit_id} + pass return Results(errors, []) diff --git a/scenario/mocking.py b/scenario/mocking.py index e2e27763..bfb89e95 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -112,9 +112,21 @@ def relation_get(self, rel_id, obj_name, app): return relation.remote_app_data elif obj_name == self.unit_name: return relation.local_unit_data + + unit_id = int(obj_name.split("/")[-1]) + + relation_type = getattr(relation, "__type__", "") + # todo replace with enum value once cyclic import is fixed + if relation_type == "regular": + return relation.remote_units_data[unit_id] + elif relation_type == "peer": + return relation.peers_data[unit_id] + elif relation_type == "subordinate": + return relation.remote_unit_data else: - unit_id = obj_name.split("/")[-1] - return relation.remote_units_data[int(unit_id)] + raise TypeError( + f"Invalid relation type for {relation}: {relation.__type__}" + ) def is_leader(self): return self._state.leader @@ -132,6 +144,7 @@ def relation_ids(self, relation_name): def relation_list(self, relation_id: int) -> Tuple[str]: relation = self._get_relation_by_id(relation_id) + # todo replace with enum value once cyclic import is fixed relation_type = getattr(relation, "__type__", "") if relation_type == "regular": return tuple( @@ -142,7 +155,7 @@ def relation_list(self, relation_id: int) -> Tuple[str]: return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) elif relation_type == "subordinate": - return f"{relation.primary_name}", + return (f"{relation.primary_name}",) else: raise RuntimeError( f"Invalid relation type: {relation_type}; should be one of " diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 66633d2e..515ff013 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -267,3 +267,33 @@ def test_relation_event_trigger(relation, evt_name, mycharm): state = State(relations=[relation]).trigger( getattr(relation, evt_name + "_event"), mycharm, meta=meta ) + + +def test_trigger_sub_relation(mycharm): + meta = { + "name": "mycharm", + "provides": { + "foo": { + "interface": "bar", + # this is a subordinate relation. + "scope": "container", + } + }, + } + + sub1 = SubordinateRelation( + "foo", remote_unit_data={"1": "2"}, primary_app_name="primary1" + ) + sub2 = SubordinateRelation( + "foo", remote_unit_data={"3": "4"}, primary_app_name="primary2" + ) + + def post_event(charm: CharmBase): + b_relations = charm.model.relations["foo"] + assert len(b_relations) == 2 + for relation in b_relations: + assert len(relation.units) == 1 + + State(relations=[sub1, sub2]).trigger( + "update-status", mycharm, meta=meta, post_event=post_event + ) From 45168a78f4192a50619555fdccc33a8fc8745dc6 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Fri, 31 Mar 2023 09:00:09 +0200 Subject: [PATCH 12/19] wip: commented out sub rel consistency check --- scenario/consistency_checker.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 2d9a3e8b..de6fc135 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -201,16 +201,17 @@ def check_relation_consistency( subs = list(filter(lambda x: isinstance(x, SubordinateRelation), state.relations)) # check subordinate relation consistency - seen_sub_primaries = set() - sub: SubordinateRelation - for sub in subs: - sig = (sub.primary_name, sub.endpoint) - if sig in seen_sub_primaries: - errors.append( - "cannot have multiple subordinate relations on the same endpoint with the same primary." - ) - break - seen_sub_primaries.add(sig) + # todo determine what this rule should be + # seen_sub_primaries = {} + # sub: SubordinateRelation + # for sub in subs: + # if seen_primary := seen_sub_primaries.get(sub.endpoint): + # if sub.primary_name != seen_primary.primary_name: + # errors.append( + # "cannot have multiple subordinate relations on the same " + # "endpoint with different primaries." + # ) + # break for sub in subs: # todo: verify that *this unit*'s id is not in {relation.remote_unit_id} From 32cb3a251369385d0a741e4e5349d4773f83c9cc Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Fri, 31 Mar 2023 10:38:16 +0200 Subject: [PATCH 13/19] better consistency checks for subs and peers --- scenario/consistency_checker.py | 75 +++++++++++++++++++++---------- scenario/state.py | 11 ++++- tests/test_consistency_checker.py | 47 ++++++++++++++++++- 3 files changed, 107 insertions(+), 26 deletions(-) diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index de6fc135..92309489 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -5,7 +5,7 @@ from scenario.runtime import InconsistentScenarioError from scenario.runtime import logger as scenario_logger -from scenario.state import SubordinateRelation, _CharmSpec, normalize_name +from scenario.state import SubordinateRelation, _CharmSpec, normalize_name, RelationType if TYPE_CHECKING: from scenario.state import Event, State @@ -21,10 +21,10 @@ class Results(NamedTuple): def check_consistency( - state: "State", - event: "Event", - charm_spec: "_CharmSpec", - juju_version: str, + state: "State", + event: "Event", + charm_spec: "_CharmSpec", + juju_version: str, ): """Validate the combination of a state, an event, a charm spec, and a juju version. @@ -49,11 +49,11 @@ def check_consistency( warnings = [] for check in ( - check_containers_consistency, - check_config_consistency, - check_event_consistency, - check_secrets_consistency, - check_relation_consistency, + check_containers_consistency, + check_config_consistency, + check_event_consistency, + check_secrets_consistency, + check_relation_consistency, ): results = check( state=state, event=event, charm_spec=charm_spec, juju_version=juju_version @@ -75,7 +75,7 @@ def check_consistency( def check_event_consistency( - *, event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the internal consistency of the Event data structure. @@ -122,7 +122,7 @@ def check_event_consistency( def check_config_consistency( - *, state: "State", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the consistency of the state.config with the charm_spec.config (config.yaml).""" state_config = state.config @@ -162,7 +162,7 @@ def check_config_consistency( def check_secrets_consistency( - *, event: "Event", state: "State", juju_version: Tuple[int, ...], **_kwargs + *, event: "Event", state: "State", juju_version: Tuple[int, ...], **_kwargs ) -> Results: """Check the consistency of Secret-related stuff.""" errors = [] @@ -183,20 +183,49 @@ def check_secrets_consistency( def check_relation_consistency( - *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: errors = [] - # check endpoint unicity + nonpeer_relations_meta = list(chain(charm_spec.meta.get("requires", {}).items(), + charm_spec.meta.get("provides", {}).items())) + peer_relations_meta = charm_spec.meta.get("peers", {}).items() + all_relations_meta = list(chain(nonpeer_relations_meta, + peer_relations_meta)) + + def _get_relations(r): + try: + return state.get_relations(r) + except ValueError: + return () + + # check relation types + for endpoint, _ in peer_relations_meta: + for relation in _get_relations(endpoint): + if relation.__type__ is not RelationType.peer: + errors.append(f"endpoint {endpoint} is a peer relation; " + f"expecting relation to be of type PeerRelation, gotten {type(relation)}") + + for endpoint, relation_meta in all_relations_meta: + expected_sub = relation_meta.get('scope', '') == 'container' + relations = _get_relations(endpoint) + for relation in relations: + is_sub = relation.__type__ is RelationType.subordinate + if is_sub and not expected_sub: + errors.append(f"endpoint {endpoint} is not a subordinate relation; " + f"expecting relation to be of type Relation, " + f"gotten {type(relation)}") + if expected_sub and not is_sub: + errors.append(f"endpoint {endpoint} is a subordinate relation; " + f"expecting relation to be of type SubordinateRelation, " + f"gotten {type(relation)}") + + # check for duplicate endpoint names seen_endpoints = set() - for rel in chain( - charm_spec.meta.get("requires", ()), - charm_spec.meta.get("provides", ()), - charm_spec.meta.get("peers", ()), - ): - if rel in seen_endpoints: + for endpoint, relation_meta in all_relations_meta: + if endpoint in seen_endpoints: errors.append("duplicate endpoint name in metadata.") break - seen_endpoints.add(rel) + seen_endpoints.add(endpoint) subs = list(filter(lambda x: isinstance(x, SubordinateRelation), state.relations)) @@ -221,7 +250,7 @@ def check_relation_consistency( def check_containers_consistency( - *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the consistency of `state.containers` vs. `charm_spec.meta` (metadata.yaml/containers).""" meta_containers = list(charm_spec.meta.get("containers", {})) diff --git a/scenario/state.py b/scenario/state.py index 59c1c792..9006413d 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -709,7 +709,7 @@ class State(_DCBase): config: Dict[str, Union[str, int, float, bool]] = dataclasses.field( default_factory=dict ) - relations: List[RelationBase] = dataclasses.field(default_factory=list) + relations: List["AnyRelation"] = dataclasses.field(default_factory=list) networks: List[Network] = dataclasses.field(default_factory=list) containers: List[Container] = dataclasses.field(default_factory=list) status: Status = dataclasses.field(default_factory=Status) @@ -755,7 +755,14 @@ def get_container(self, container: Union[str, Container]) -> Container: except StopIteration as e: raise ValueError(f"container: {name}") from e - # FIXME: not a great way to obtain a delta, but is "complete" todo figure out a better way. + def get_relations(self, endpoint: str) -> Tuple["AnyRelation"]: + """Get relation from this State, based on an input relation or its endpoint name.""" + try: + return tuple(filter(lambda c: c.endpoint == endpoint, self.relations)) + except StopIteration as e: + raise ValueError(f"relation: {endpoint}") from e + + # FIXME: not a great way to obtain a delta, but is "complete". todo figure out a better way. def jsonpatch_delta(self, other: "State"): try: import jsonpatch diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 6e82119a..67881e01 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -10,7 +10,7 @@ Relation, Secret, State, - _CharmSpec, + _CharmSpec, PeerRelation, SubordinateRelation, ) @@ -154,3 +154,48 @@ def test_secrets_jujuv_bad(good_v): _CharmSpec(MyCharm, {}), good_v, ) + + +def test_peer_relation_consistency(): + assert_inconsistent( + State(relations=[Relation('foo')]), + Event("bar"), + _CharmSpec(MyCharm, { + 'peers': {'foo': {'interface': 'bar'}} + }), + ) + assert_consistent( + State(relations=[PeerRelation('foo')]), + Event("bar"), + _CharmSpec(MyCharm, { + 'peers': {'foo': {'interface': 'bar'}} + }), + ) + + +def test_sub_relation_consistency(): + assert_inconsistent( + State(relations=[Relation('foo')]), + Event("bar"), + _CharmSpec(MyCharm, { + 'requires': {'foo': {'interface': 'bar', 'scope': 'container'}} + }), + ) + assert_consistent( + State(relations=[SubordinateRelation('foo')]), + Event("bar"), + _CharmSpec(MyCharm, { + 'requires': {'foo': {'interface': 'bar', 'scope': 'container'}} + }), + ) + + +def test_relation_sub_inconsistent(): + assert_inconsistent( + State(relations=[SubordinateRelation('foo')]), + Event("bar"), + _CharmSpec(MyCharm, { + 'requires': {'foo': {'interface': 'bar'}} + }), + ) + From 0c99ab9e9e7213b874e928b830206e706c8d61a3 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Fri, 31 Mar 2023 12:46:13 +0200 Subject: [PATCH 14/19] dupe container inconsistency --- README.md | 6 +++ scenario/consistency_checker.py | 71 +++++++++++++++++++------------ tests/test_consistency_checker.py | 45 +++++++++++--------- 3 files changed, 73 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 0ab90774..96d6f9be 100644 --- a/README.md +++ b/README.md @@ -211,7 +211,13 @@ def test_relation_data(): # which is very idiomatic and superbly explicit. Noice. ``` +## Relation types +When you use `Relation`, you are specifying a 'normal' relation. But that is not the only type of relation. There are also +peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app' is, because it's the same application. + + +## Triggering Relation Events If you want to trigger relation events, the easiest way to do so is get a hold of the Relation instance and grab the event from one of its aptly-named properties: ```python diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 92309489..23a938d6 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -5,7 +5,7 @@ from scenario.runtime import InconsistentScenarioError from scenario.runtime import logger as scenario_logger -from scenario.state import SubordinateRelation, _CharmSpec, normalize_name, RelationType +from scenario.state import RelationType, SubordinateRelation, _CharmSpec, normalize_name if TYPE_CHECKING: from scenario.state import Event, State @@ -21,10 +21,10 @@ class Results(NamedTuple): def check_consistency( - state: "State", - event: "Event", - charm_spec: "_CharmSpec", - juju_version: str, + state: "State", + event: "Event", + charm_spec: "_CharmSpec", + juju_version: str, ): """Validate the combination of a state, an event, a charm spec, and a juju version. @@ -49,11 +49,11 @@ def check_consistency( warnings = [] for check in ( - check_containers_consistency, - check_config_consistency, - check_event_consistency, - check_secrets_consistency, - check_relation_consistency, + check_containers_consistency, + check_config_consistency, + check_event_consistency, + check_secrets_consistency, + check_relation_consistency, ): results = check( state=state, event=event, charm_spec=charm_spec, juju_version=juju_version @@ -75,7 +75,7 @@ def check_consistency( def check_event_consistency( - *, event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the internal consistency of the Event data structure. @@ -122,7 +122,7 @@ def check_event_consistency( def check_config_consistency( - *, state: "State", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the consistency of the state.config with the charm_spec.config (config.yaml).""" state_config = state.config @@ -162,7 +162,7 @@ def check_config_consistency( def check_secrets_consistency( - *, event: "Event", state: "State", juju_version: Tuple[int, ...], **_kwargs + *, event: "Event", state: "State", juju_version: Tuple[int, ...], **_kwargs ) -> Results: """Check the consistency of Secret-related stuff.""" errors = [] @@ -183,14 +183,17 @@ def check_secrets_consistency( def check_relation_consistency( - *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: errors = [] - nonpeer_relations_meta = list(chain(charm_spec.meta.get("requires", {}).items(), - charm_spec.meta.get("provides", {}).items())) + nonpeer_relations_meta = list( + chain( + charm_spec.meta.get("requires", {}).items(), + charm_spec.meta.get("provides", {}).items(), + ) + ) peer_relations_meta = charm_spec.meta.get("peers", {}).items() - all_relations_meta = list(chain(nonpeer_relations_meta, - peer_relations_meta)) + all_relations_meta = list(chain(nonpeer_relations_meta, peer_relations_meta)) def _get_relations(r): try: @@ -202,22 +205,28 @@ def _get_relations(r): for endpoint, _ in peer_relations_meta: for relation in _get_relations(endpoint): if relation.__type__ is not RelationType.peer: - errors.append(f"endpoint {endpoint} is a peer relation; " - f"expecting relation to be of type PeerRelation, gotten {type(relation)}") + errors.append( + f"endpoint {endpoint} is a peer relation; " + f"expecting relation to be of type PeerRelation, gotten {type(relation)}" + ) for endpoint, relation_meta in all_relations_meta: - expected_sub = relation_meta.get('scope', '') == 'container' + expected_sub = relation_meta.get("scope", "") == "container" relations = _get_relations(endpoint) for relation in relations: is_sub = relation.__type__ is RelationType.subordinate if is_sub and not expected_sub: - errors.append(f"endpoint {endpoint} is not a subordinate relation; " - f"expecting relation to be of type Relation, " - f"gotten {type(relation)}") + errors.append( + f"endpoint {endpoint} is not a subordinate relation; " + f"expecting relation to be of type Relation, " + f"gotten {type(relation)}" + ) if expected_sub and not is_sub: - errors.append(f"endpoint {endpoint} is a subordinate relation; " - f"expecting relation to be of type SubordinateRelation, " - f"gotten {type(relation)}") + errors.append( + f"endpoint {endpoint} is a subordinate relation; " + f"expecting relation to be of type SubordinateRelation, " + f"gotten {type(relation)}" + ) # check for duplicate endpoint names seen_endpoints = set() @@ -250,7 +259,7 @@ def _get_relations(r): def check_containers_consistency( - *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: """Check the consistency of `state.containers` vs. `charm_spec.meta` (metadata.yaml/containers).""" meta_containers = list(charm_spec.meta.get("containers", {})) @@ -279,4 +288,10 @@ def check_containers_consistency( f"some containers declared in the state are not specified in metadata. That's not possible. " f"Missing from metadata: {diff}." ) + + # guard against duplicate container names + names = Counter(state_containers) + if dupes := [n for n in names if names[n] > 1]: + errors.append(f"Duplicate container name(s): {dupes}.") + return Results(errors, []) diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 67881e01..3a8511f9 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -7,10 +7,12 @@ RELATION_EVENTS_SUFFIX, Container, Event, + PeerRelation, Relation, Secret, State, - _CharmSpec, PeerRelation, SubordinateRelation, + SubordinateRelation, + _CharmSpec, ) @@ -158,44 +160,45 @@ def test_secrets_jujuv_bad(good_v): def test_peer_relation_consistency(): assert_inconsistent( - State(relations=[Relation('foo')]), + State(relations=[Relation("foo")]), Event("bar"), - _CharmSpec(MyCharm, { - 'peers': {'foo': {'interface': 'bar'}} - }), + _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), ) assert_consistent( - State(relations=[PeerRelation('foo')]), + State(relations=[PeerRelation("foo")]), Event("bar"), - _CharmSpec(MyCharm, { - 'peers': {'foo': {'interface': 'bar'}} - }), + _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), ) def test_sub_relation_consistency(): assert_inconsistent( - State(relations=[Relation('foo')]), + State(relations=[Relation("foo")]), Event("bar"), - _CharmSpec(MyCharm, { - 'requires': {'foo': {'interface': 'bar', 'scope': 'container'}} - }), + _CharmSpec( + MyCharm, {"requires": {"foo": {"interface": "bar", "scope": "container"}}} + ), ) assert_consistent( - State(relations=[SubordinateRelation('foo')]), + State(relations=[SubordinateRelation("foo")]), Event("bar"), - _CharmSpec(MyCharm, { - 'requires': {'foo': {'interface': 'bar', 'scope': 'container'}} - }), + _CharmSpec( + MyCharm, {"requires": {"foo": {"interface": "bar", "scope": "container"}}} + ), ) def test_relation_sub_inconsistent(): assert_inconsistent( - State(relations=[SubordinateRelation('foo')]), + State(relations=[SubordinateRelation("foo")]), Event("bar"), - _CharmSpec(MyCharm, { - 'requires': {'foo': {'interface': 'bar'}} - }), + _CharmSpec(MyCharm, {"requires": {"foo": {"interface": "bar"}}}), ) + +def test_dupe_containers_inconsistent(): + assert_inconsistent( + State(containers=[Container("foo"), Container("foo")]), + Event("bar"), + _CharmSpec(MyCharm, {"containers": {"foo": {}}}), + ) From 8f93d1d17e4194e24b5d14e1d862dcafbebe579c Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Mon, 3 Apr 2023 10:37:28 +0200 Subject: [PATCH 15/19] added todo --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 96d6f9be..80b2e163 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,9 @@ When you use `Relation`, you are specifying a 'normal' relation. But that is not peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app' is, because it's the same application. +TODO: describe peer/sub API. + + ## Triggering Relation Events If you want to trigger relation events, the easiest way to do so is get a hold of the Relation instance and grab the event from one of its aptly-named properties: From 823e7619c172c4b779cf911e8dc75ed01a79fb47 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Mon, 3 Apr 2023 16:03:42 +0200 Subject: [PATCH 16/19] defaulted remote_unit --- README.md | 4 ++-- scenario/runtime.py | 12 ++++++------ scenario/state.py | 10 +++++----- tests/test_e2e/test_relations.py | 10 +++++----- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 80b2e163..78b83af5 100644 --- a/README.md +++ b/README.md @@ -243,12 +243,12 @@ The reason for this construction is that the event is associated with some relat ### Additional event parameters All relation events have some additional metadata that does not belong in the Relation object, such as, for a relation-joined event, the name of the (remote) unit that is joining the relation. That is what determines what `ops.model.Unit` you get when you get `RelationJoinedEvent().unit` in an event handler. -In order to supply this parameter, you will have to **call** the event object and pass as `remote_unit` the id of the remote unit that the event is about. +In order to supply this parameter, you will have to **call** the event object and pass as `remote_unit_id` the id of the remote unit that the event is about. ```python from scenario import Relation, Event relation = Relation(endpoint="foo", interface="bar") -remote_unit_2_is_joining_event = relation.joined_event(remote_unit=2) +remote_unit_2_is_joining_event = relation.joined_event(remote_unit_id=2) # which is syntactic sugar for: remote_unit_2_is_joining_event = Event('foo-relation-changed', relation=relation, relation_remote_unit_id=2) diff --git a/scenario/runtime.py b/scenario/runtime.py index 964e216e..595f376c 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -168,6 +168,7 @@ def _cleanup_env(env): # os.unsetenv does not work !? del os.environ[key] + def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event.name.endswith("_action"): # todo: do we need some special metadata, or can we assume action names are always dashes? @@ -229,15 +230,15 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): remote_unit_id = remote_unit_ids[0] logger.info( "there's only one remote unit, so we set JUJU_REMOTE_UNIT to it, " - "but you probably should be parametrizing the event with `remote_unit` " + "but you probably should be parametrizing the event with `remote_unit_id` " "to be explicit." ) else: + remote_unit_id = remote_unit_ids[0] logger.warning( - "unable to determine remote unit ID; which means JUJU_REMOTE_UNIT will " - "be unset and you might get error if charm code attempts to access " - "`event.unit` in event handlers. \n" - "If that is the case, pass `remote_unit` to the Event constructor." + "remote unit ID unset, and multiple remote unit IDs are present; " + "We will pick the first one and hope for the best. You should be passing " + "`remote_unit_id` to the Event constructor." ) if remote_unit_id is not None: @@ -410,7 +411,6 @@ def exec( logger.info(" - Clearing env") self._cleanup_env(env) - assert not os.getenv("JUJU_DEPARTING_UNIT") logger.info(" - closing storage") output_state = self._close_storage(output_state, temporary_charm_root) diff --git a/scenario/state.py b/scenario/state.py index 9006413d..9e6db986 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -169,7 +169,7 @@ def __call__(self, remote_unit: Optional[str] = None) -> "Event": f"{self._category} event constructor." ) - return Event(*self._args, *self._kwargs, relation_remote_unit=remote_unit) + return Event(*self._args, *self._kwargs, relation_remote_unit_id=remote_unit) def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": return self().deferred(handler=handler, event_id=event_id) @@ -889,13 +889,13 @@ class Event(_DCBase): # - pebble? # - action? - def __call__(self, remote_unit: Optional[int] = None) -> "Event": - if remote_unit and not self._is_relation_event: + def __call__(self, remote_unit_id: Optional[int] = None) -> "Event": + if remote_unit_id and not self._is_relation_event: raise ValueError( - "cannot pass param `remote_unit` to a " + "cannot pass param `remote_unit_id` to a " "non-relation event constructor." ) - return self.replace(relation_remote_unit_id=remote_unit) + return self.replace(relation_remote_unit_id=remote_unit_id) def __post_init__(self): if "-" in self.name: diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 515ff013..36c57432 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -165,7 +165,7 @@ def callback(charm: CharmBase, event): relation, ], ).trigger( - getattr(relation, f"{evt_name}_event")(remote_unit=remote_unit_id), + getattr(relation, f"{evt_name}_event")(remote_unit_id=remote_unit_id), mycharm, meta={ "name": "local", @@ -194,8 +194,8 @@ def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name, caplog): def callback(charm: CharmBase, event): assert event.app # that's always present - assert not event.unit - assert not getattr(event, "departing_unit", False) + assert event.unit + assert (evt_name == 'departed') is bool(getattr(event, "departing_unit", False)) mycharm._call = callback @@ -214,7 +214,7 @@ def callback(charm: CharmBase, event): }, ) - assert "unable to determine remote unit ID" in caplog.text + assert "remote unit ID unset, and multiple remote unit IDs are present" in caplog.text @pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) @@ -249,7 +249,7 @@ def test_relation_type(relation, expected_type): ) @pytest.mark.parametrize( "relation", - (Relation("a"), PeerRelation("b"), SubordinateRelation("b")), + (Relation("a"), PeerRelation("b"), SubordinateRelation("c")), ) def test_relation_event_trigger(relation, evt_name, mycharm): meta = { From b67bb1b320da396f49c946252c82dad82554e35a Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Tue, 4 Apr 2023 11:19:28 +0200 Subject: [PATCH 17/19] databags dedundered --- README.md | 7 +++++-- scenario/consistency_checker.py | 35 +++++++------------------------- scenario/mocking.py | 1 - scenario/runtime.py | 1 - scenario/state.py | 18 +++++++++++----- tests/test_e2e/test_relations.py | 6 ++++-- 6 files changed, 29 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 78b83af5..dbe44e9b 100644 --- a/README.md +++ b/README.md @@ -213,8 +213,8 @@ def test_relation_data(): ``` ## Relation types -When you use `Relation`, you are specifying a 'normal' relation. But that is not the only type of relation. There are also -peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app' is, because it's the same application. +When you use `Relation`, you are specifying a regular (conventional) relation. But that is not the only type of relation. There are also +peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app', because it's the same application. TODO: describe peer/sub API. @@ -244,6 +244,9 @@ The reason for this construction is that the event is associated with some relat All relation events have some additional metadata that does not belong in the Relation object, such as, for a relation-joined event, the name of the (remote) unit that is joining the relation. That is what determines what `ops.model.Unit` you get when you get `RelationJoinedEvent().unit` in an event handler. In order to supply this parameter, you will have to **call** the event object and pass as `remote_unit_id` the id of the remote unit that the event is about. +The reason that this parameter is not supplied to `Relation` but to relation events, is that the relation already ties 'this app' to some 'remote app' (cfr. the `Relation.remote_app_name` attr), but not to a specific unit. What remote unit this event is about is not a `State` concern but an `Event` one. + +The `remote_unit_id` will default to the first ID found in the relation's `remote_unit_ids`, but if the test you are writing is close to that domain, you should probably override it and pass it manually. ```python from scenario import Relation, Event diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 23a938d6..c5b403fc 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -186,11 +186,9 @@ def check_relation_consistency( *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: errors = [] - nonpeer_relations_meta = list( - chain( - charm_spec.meta.get("requires", {}).items(), - charm_spec.meta.get("provides", {}).items(), - ) + nonpeer_relations_meta = chain( + charm_spec.meta.get("requires", {}).items(), + charm_spec.meta.get("provides", {}).items(), ) peer_relations_meta = charm_spec.meta.get("peers", {}).items() all_relations_meta = list(chain(nonpeer_relations_meta, peer_relations_meta)) @@ -207,7 +205,7 @@ def _get_relations(r): if relation.__type__ is not RelationType.peer: errors.append( f"endpoint {endpoint} is a peer relation; " - f"expecting relation to be of type PeerRelation, gotten {type(relation)}" + f"expecting relation to be of type PeerRelation, got {type(relation)}" ) for endpoint, relation_meta in all_relations_meta: @@ -219,13 +217,13 @@ def _get_relations(r): errors.append( f"endpoint {endpoint} is not a subordinate relation; " f"expecting relation to be of type Relation, " - f"gotten {type(relation)}" + f"got {type(relation)}" ) if expected_sub and not is_sub: errors.append( - f"endpoint {endpoint} is a subordinate relation; " + f"endpoint {endpoint} is not a subordinate relation; " f"expecting relation to be of type SubordinateRelation, " - f"gotten {type(relation)}" + f"got {type(relation)}" ) # check for duplicate endpoint names @@ -236,25 +234,6 @@ def _get_relations(r): break seen_endpoints.add(endpoint) - subs = list(filter(lambda x: isinstance(x, SubordinateRelation), state.relations)) - - # check subordinate relation consistency - # todo determine what this rule should be - # seen_sub_primaries = {} - # sub: SubordinateRelation - # for sub in subs: - # if seen_primary := seen_sub_primaries.get(sub.endpoint): - # if sub.primary_name != seen_primary.primary_name: - # errors.append( - # "cannot have multiple subordinate relations on the same " - # "endpoint with different primaries." - # ) - # break - - for sub in subs: - # todo: verify that *this unit*'s id is not in {relation.remote_unit_id} - pass - return Results(errors, []) diff --git a/scenario/mocking.py b/scenario/mocking.py index bfb89e95..b3ae87fd 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -104,7 +104,6 @@ def _generate_secret_id(): return f"secret:{id}" def relation_get(self, rel_id, obj_name, app): - # fixme: this WILL definitely bork with peer and sub relation types. relation = self._get_relation_by_id(rel_id) if app and obj_name == self.app_name: return relation.local_app_data diff --git a/scenario/runtime.py b/scenario/runtime.py index 595f376c..0219781a 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -168,7 +168,6 @@ def _cleanup_env(env): # os.unsetenv does not work !? del os.environ[key] - def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event.name.endswith("_action"): # todo: do we need some special metadata, or can we assume action names are always dashes? diff --git a/scenario/state.py b/scenario/state.py index 9e6db986..f09925e6 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -198,7 +198,8 @@ class RelationBase(_DCBase): local_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) @property - def __databags__(self): + def _databags(self): + """Yield all databags in this relation.""" yield self.local_app_data yield self.local_unit_data @@ -206,9 +207,13 @@ def __post_init__(self): global _RELATION_IDS_CTR if self.relation_id == -1: _RELATION_IDS_CTR += 1 + logger.info( + f"relation ID unset; automatically assigning {_RELATION_IDS_CTR}. " + f"If there are problems, pass one manually." + ) self.relation_id = _RELATION_IDS_CTR - for databag in self.__databags__: + for databag in self._databags: self._validate_databag(databag) def _validate_databag(self, databag: dict): @@ -316,7 +321,8 @@ class Relation(RelationBase): ) @property - def __databags__(self): + def _databags(self): + """Yield all databags in this relation.""" yield self.local_app_data yield self.local_unit_data yield self.remote_app_data @@ -342,7 +348,8 @@ class SubordinateRelation(RelationBase): primary_id: int = 0 @property - def __databags__(self): + def _databags(self): + """Yield all databags in this relation.""" yield self.local_app_data yield self.local_unit_data yield self.remote_app_data @@ -362,7 +369,8 @@ class PeerRelation(RelationBase): peers_ids: List[int] = dataclasses.field(default_factory=list) @property - def __databags__(self): + def _databags(self): + """Yield all databags in this relation.""" yield self.local_app_data yield self.local_unit_data yield from self.peers_data.values() diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 36c57432..cd44b28b 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -195,7 +195,7 @@ def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name, caplog): def callback(charm: CharmBase, event): assert event.app # that's always present assert event.unit - assert (evt_name == 'departed') is bool(getattr(event, "departing_unit", False)) + assert (evt_name == "departed") is bool(getattr(event, "departing_unit", False)) mycharm._call = callback @@ -214,7 +214,9 @@ def callback(charm: CharmBase, event): }, ) - assert "remote unit ID unset, and multiple remote unit IDs are present" in caplog.text + assert ( + "remote unit ID unset, and multiple remote unit IDs are present" in caplog.text + ) @pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) From 14f3baf50daab59f98c16273aea95bca9edabbb9 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 5 Apr 2023 10:36:42 +0200 Subject: [PATCH 18/19] cyclic imports fixed --- scenario/consistency_checker.py | 6 +-- scenario/fs_mocks.py | 35 +++++++++++++ scenario/mocking.py | 69 ++++--------------------- scenario/ops_main_mock.py | 4 +- scenario/runtime.py | 26 ++-------- scenario/state.py | 89 ++++++++++++++++++++++++-------- tests/test_e2e/test_relations.py | 21 +++----- 7 files changed, 128 insertions(+), 122 deletions(-) create mode 100644 scenario/fs_mocks.py diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index c5b403fc..98ee93ad 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -5,7 +5,7 @@ from scenario.runtime import InconsistentScenarioError from scenario.runtime import logger as scenario_logger -from scenario.state import RelationType, SubordinateRelation, _CharmSpec, normalize_name +from scenario.state import PeerRelation, SubordinateRelation, _CharmSpec, normalize_name if TYPE_CHECKING: from scenario.state import Event, State @@ -202,7 +202,7 @@ def _get_relations(r): # check relation types for endpoint, _ in peer_relations_meta: for relation in _get_relations(endpoint): - if relation.__type__ is not RelationType.peer: + if not isinstance(relation, PeerRelation): errors.append( f"endpoint {endpoint} is a peer relation; " f"expecting relation to be of type PeerRelation, got {type(relation)}" @@ -212,7 +212,7 @@ def _get_relations(r): expected_sub = relation_meta.get("scope", "") == "container" relations = _get_relations(endpoint) for relation in relations: - is_sub = relation.__type__ is RelationType.subordinate + is_sub = isinstance(relation, SubordinateRelation) if is_sub and not expected_sub: errors.append( f"endpoint {endpoint} is not a subordinate relation; " diff --git a/scenario/fs_mocks.py b/scenario/fs_mocks.py new file mode 100644 index 00000000..38548aec --- /dev/null +++ b/scenario/fs_mocks.py @@ -0,0 +1,35 @@ +import pathlib +from typing import Dict + +from ops.testing import _TestingFilesystem, _TestingStorageMount # noqa + + +# todo consider duplicating the filesystem on State.copy() to be able to diff and have true state snapshots +class _MockStorageMount(_TestingStorageMount): + def __init__(self, location: pathlib.PurePosixPath, src: pathlib.Path): + """Creates a new simulated storage mount. + + Args: + location: The path within simulated filesystem at which this storage will be mounted. + src: The temporary on-disk location where the simulated storage will live. + """ + self._src = src + self._location = location + + try: + # for some reason this fails if src exists, even though exists_ok=True. + super().__init__(location=location, src=src) + except FileExistsError: + pass + + +class _MockFileSystem(_TestingFilesystem): + def __init__(self, mounts: Dict[str, _MockStorageMount]): + super().__init__() + self._mounts = mounts + + def add_mount(self, *args, **kwargs): + raise NotImplementedError("Cannot mutate mounts; declare them all in State.") + + def remove_mount(self, *args, **kwargs): + raise NotImplementedError("Cannot mutate mounts; declare them all in State.") diff --git a/scenario/mocking.py b/scenario/mocking.py index b3ae87fd..8bfda1c4 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -2,7 +2,6 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. import datetime -import pathlib import random from io import StringIO from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union @@ -10,9 +9,10 @@ from ops import pebble from ops.model import SecretInfo, SecretRotate, _ModelBackend from ops.pebble import Client, ExecError -from ops.testing import _TestingFilesystem, _TestingPebbleClient, _TestingStorageMount +from ops.testing import _TestingPebbleClient from scenario.logger import logger as scenario_logger +from scenario.state import PeerRelation if TYPE_CHECKING: from scenario.state import Container as ContainerSpec @@ -113,19 +113,7 @@ def relation_get(self, rel_id, obj_name, app): return relation.local_unit_data unit_id = int(obj_name.split("/")[-1]) - - relation_type = getattr(relation, "__type__", "") - # todo replace with enum value once cyclic import is fixed - if relation_type == "regular": - return relation.remote_units_data[unit_id] - elif relation_type == "peer": - return relation.peers_data[unit_id] - elif relation_type == "subordinate": - return relation.remote_unit_data - else: - raise TypeError( - f"Invalid relation type for {relation}: {relation.__type__}" - ) + return relation._get_databag_for_remote(unit_id) # noqa def is_leader(self): return self._state.leader @@ -143,23 +131,13 @@ def relation_ids(self, relation_name): def relation_list(self, relation_id: int) -> Tuple[str]: relation = self._get_relation_by_id(relation_id) - # todo replace with enum value once cyclic import is fixed - relation_type = getattr(relation, "__type__", "") - if relation_type == "regular": - return tuple( - f"{relation.remote_app_name}/{unit_id}" - for unit_id in relation.remote_unit_ids - ) - elif relation_type == "peer": - return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) - elif relation_type == "subordinate": - return (f"{relation.primary_name}",) - else: - raise RuntimeError( - f"Invalid relation type: {relation_type}; should be one of " - f"scenario.state.RelationType" - ) + if isinstance(relation, PeerRelation): + return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) + return tuple( + f"{relation._remote_app_name}/{unit_id}" # noqa + for unit_id in relation._remote_unit_ids # noqa + ) def config_get(self): state_config = self._state.config @@ -352,35 +330,6 @@ def planned_units(self, *args, **kwargs): raise NotImplementedError("planned_units") -class _MockStorageMount(_TestingStorageMount): - def __init__(self, location: pathlib.PurePosixPath, src: pathlib.Path): - """Creates a new simulated storage mount. - - Args: - location: The path within simulated filesystem at which this storage will be mounted. - src: The temporary on-disk location where the simulated storage will live. - """ - self._src = src - self._location = location - if ( - not src.exists() - ): # we need to add this guard because the directory might exist already. - src.mkdir(exist_ok=True, parents=True) - - -# todo consider duplicating the filesystem on State.copy() to be able to diff and have true state snapshots -class _MockFileSystem(_TestingFilesystem): - def __init__(self, mounts: Dict[str, _MockStorageMount]): - super().__init__() - self._mounts = mounts - - def add_mount(self, *args, **kwargs): - raise NotImplementedError("Cannot mutate mounts; declare them all in State.") - - def remove_mount(self, *args, **kwargs): - raise NotImplementedError("Cannot mutate mounts; declare them all in State.") - - class _MockPebbleClient(_TestingPebbleClient): def __init__( self, diff --git a/scenario/ops_main_mock.py b/scenario/ops_main_mock.py index e7723ee3..85b34a78 100644 --- a/scenario/ops_main_mock.py +++ b/scenario/ops_main_mock.py @@ -14,7 +14,6 @@ from ops.main import CHARM_STATE_FILE, _Dispatcher, _emit_charm_event, _get_charm_dir from scenario.logger import logger as scenario_logger -from scenario.mocking import _MockModelBackend if TYPE_CHECKING: from ops.testing import CharmType @@ -38,6 +37,9 @@ def main( """Set up the charm and dispatch the observed event.""" charm_class = charm_spec.charm_type charm_dir = _get_charm_dir() + + from scenario.mocking import _MockModelBackend + model_backend = _MockModelBackend( # pyright: reportPrivateUsage=false state=state, event=event, charm_spec=charm_spec ) diff --git a/scenario/runtime.py b/scenario/runtime.py index 0219781a..647035df 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -25,6 +25,7 @@ from scenario.logger import logger as scenario_logger from scenario.ops_main_mock import NoObserverError +from scenario.state import DeferredEvent, PeerRelation, StoredState if TYPE_CHECKING: from ops.testing import CharmType @@ -80,7 +81,6 @@ def _open_db(self) -> Optional[SQLiteStorage]: def get_stored_state(self) -> List["StoredState"]: """Load any StoredState data structures from the db.""" - from scenario.state import StoredState # avoid cyclic import db = self._open_db() @@ -99,7 +99,6 @@ def get_stored_state(self) -> List["StoredState"]: def get_deferred_events(self) -> List["DeferredEvent"]: """Load any DeferredEvent data structures from the db.""" - from scenario.state import DeferredEvent # avoid cyclic import db = self._open_db() @@ -188,20 +187,12 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): } relation: "AnyRelation" - from scenario.state import RelationType # avoid cyclic import # todo refactor if event._is_relation_event and (relation := event.relation): # noqa - if relation.__type__ == RelationType.regular: - remote_app_name = relation.remote_app_name - elif relation.__type__ == RelationType.peer: + if isinstance(relation, PeerRelation): remote_app_name = self._app_name - elif relation.__type__ == RelationType.subordinate: - remote_app_name = relation.primary_app_name else: - raise TypeError( - f"Invalid relation type for {relation}: {relation.__type__}" - ) - + remote_app_name = relation._remote_app_name # noqa env.update( { "JUJU_RELATION": relation.endpoint, @@ -214,16 +205,7 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if ( remote_unit_id is None ): # don't check truthiness because it could be int(0) - if relation.__type__ == RelationType.regular: - remote_unit_ids = relation.remote_unit_ids - elif relation.__type__ == RelationType.peer: - remote_unit_ids = relation.peers_ids - elif relation.__type__ == RelationType.subordinate: - remote_unit_ids = [relation.primary_id] - else: - raise TypeError( - f"Invalid relation type for {relation}: {relation.__type__}" - ) + remote_unit_ids = relation._remote_unit_ids # noqa if len(remote_unit_ids) == 1: remote_unit_id = remote_unit_ids[0] diff --git a/scenario/state.py b/scenario/state.py index f09925e6..4ec81857 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -7,7 +7,6 @@ import inspect import re import typing -from enum import Enum from itertools import chain from pathlib import Path, PurePosixPath from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union @@ -18,9 +17,8 @@ from ops.charm import CharmEvents from ops.model import SecretRotate, StatusBase +from scenario.fs_mocks import _MockFileSystem, _MockStorageMount from scenario.logger import logger as scenario_logger -from scenario.mocking import _MockFileSystem, _MockStorageMount -from scenario.runtime import trigger as _runtime_trigger if typing.TYPE_CHECKING: try: @@ -32,7 +30,6 @@ PathLike = Union[str, Path] AnyRelation = Union["Relation", "PeerRelation", "SubordinateRelation"] - logger = scenario_logger.getChild("state") ATTACH_ALL_STORAGES = "ATTACH_ALL_STORAGES" @@ -175,17 +172,8 @@ def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": return self().deferred(handler=handler, event_id=event_id) -class RelationType(str, Enum): - subordinate = "subordinate" - regular = "regular" - peer = "peer" - - @dataclasses.dataclass class RelationBase(_DCBase): - if typing.TYPE_CHECKING: - __type__: RelationType - endpoint: str # we can derive this from the charm's metadata @@ -203,7 +191,27 @@ def _databags(self): yield self.local_app_data yield self.local_unit_data + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + raise NotImplementedError() + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + raise NotImplementedError() + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + raise NotImplementedError() + def __post_init__(self): + if type(self) is RelationBase: + raise RuntimeError( + "RelationBase cannot be instantiated directly; " + "please use Relation, PeerRelation, or SubordinateRelation" + ) + global _RELATION_IDS_CTR if self.relation_id == -1: _RELATION_IDS_CTR += 1 @@ -308,7 +316,6 @@ def unify_ids_and_remote_units_data(ids: List[int], data: Dict[int, Any]): @dataclasses.dataclass class Relation(RelationBase): - __type__ = RelationType.regular remote_app_name: str = "remote" remote_unit_ids: List[int] = dataclasses.field(default_factory=list) @@ -320,6 +327,20 @@ class Relation(RelationBase): default_factory=dict ) + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + return self.remote_app_name + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return tuple(self.remote_unit_ids) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.remote_units_data[unit_id] + @property def _databags(self): """Yield all databags in this relation.""" @@ -337,8 +358,6 @@ def __post_init__(self): @dataclasses.dataclass class SubordinateRelation(RelationBase): - __type__ = RelationType.subordinate - # todo: consider renaming them to primary_*_data remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) remote_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -347,6 +366,20 @@ class SubordinateRelation(RelationBase): primary_app_name: str = "remote" primary_id: int = 0 + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + return self.primary_app_name + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return (self.primary_id,) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.remote_unit_data + @property def _databags(self): """Yield all databags in this relation.""" @@ -362,7 +395,6 @@ def primary_name(self) -> str: @dataclasses.dataclass class PeerRelation(RelationBase): - __type__ = RelationType.peer peers_data: Dict[int, Dict[str, str]] = dataclasses.field(default_factory=dict) # IDs of the peers. Consistency checks will validate that *this unit*'s ID is not in here. @@ -375,6 +407,21 @@ def _databags(self): yield self.local_unit_data yield from self.peers_data.values() + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + # surprise! It's myself. + raise ValueError("peer relations don't quite have a remote end.") + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return tuple(self.peers_ids) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.peers_data[unit_id] + def __post_init__(self): self.peers_ids, self.peers_data = unify_ids_and_remote_units_data( self.peers_ids, self.peers_data @@ -516,7 +563,7 @@ def services(self) -> Dict[str, pebble.ServiceInfo]: return infos @property - def filesystem(self) -> _MockFileSystem: + def filesystem(self) -> "_MockFileSystem": mounts = { name: _MockStorageMount( src=Path(spec.src), location=PurePosixPath(spec.location) @@ -801,6 +848,8 @@ def trigger( juju_version: str = "3.0", ) -> "State": """Fluent API for trigger. See runtime.trigger's docstring.""" + from scenario.runtime import trigger as _runtime_trigger + return _runtime_trigger( state=self, event=event, @@ -814,8 +863,6 @@ def trigger( juju_version=juju_version, ) - trigger.__doc__ = _runtime_trigger.__doc__ - @dataclasses.dataclass class _CharmSpec(_DCBase): @@ -882,7 +929,7 @@ class Event(_DCBase): kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # if this is a relation event, the relation it refers to - relation: Optional[Relation] = None + relation: Optional["AnyRelation"] = None # and the name of the remote unit this relation event is about relation_remote_unit_id: Optional[int] = None diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index cd44b28b..7d429aa9 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -1,15 +1,13 @@ -import os from typing import Type import pytest from ops.charm import CharmBase, CharmEvents, RelationDepartedEvent from ops.framework import EventBase, Framework -from scenario.runtime import InconsistentScenarioError from scenario.state import ( PeerRelation, Relation, - RelationType, + RelationBase, State, StateValidationError, SubordinateRelation, @@ -233,18 +231,6 @@ def test_relation_app_data_bad_types(mycharm, data): relation = Relation(endpoint="foo", interface="foo", local_app_data={"a": data}) -@pytest.mark.parametrize( - "relation, expected_type", - ( - (Relation("a"), RelationType.regular), - (PeerRelation("b"), RelationType.peer), - (SubordinateRelation("b"), RelationType.subordinate), - ), -) -def test_relation_type(relation, expected_type): - assert relation.__type__ == expected_type - - @pytest.mark.parametrize( "evt_name", ("changed", "broken", "departed", "joined", "created"), @@ -299,3 +285,8 @@ def post_event(charm: CharmBase): State(relations=[sub1, sub2]).trigger( "update-status", mycharm, meta=meta, post_event=post_event ) + + +def test_cannot_instantiate_relationbase(): + with pytest.raises(RuntimeError): + RelationBase("") From cb9f43217f1aa69a045d6383abe97cd31b49b3b9 Mon Sep 17 00:00:00 2001 From: Pietro Pasotti Date: Wed, 5 Apr 2023 11:01:36 +0200 Subject: [PATCH 19/19] docs --- README.md | 58 ++++++++++++++++++++++++++++++++++++++++--- scenario/runtime.py | 8 ++++-- scenario/sequences.py | 2 ++ scenario/state.py | 2 ++ tests/test_runtime.py | 6 +++-- 5 files changed, 69 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index dbe44e9b..c618c9d6 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ def _on_event(self, _event): You can verify that the charm has followed the expected path by checking the **unit status history** like so: ```python +from charm import MyCharm from ops.model import MaintenanceStatus, ActiveStatus, WaitingStatus, UnknownStatus from scenario import State @@ -148,6 +149,7 @@ def test_statuses(): UnknownStatus(), MaintenanceStatus('determining who the ruler is...'), WaitingStatus('checking this is right...'), + ActiveStatus("I am ruled"), ] ``` @@ -155,7 +157,7 @@ Note that the current status is not in the **unit status history**. Also note that, unless you initialize the State with a preexisting status, the first status in the history will always be `unknown`. That is because, so far as scenario is concerned, each event is "the first event this charm has ever seen". -If you want to simulate a situation in which the charm already has seen some event, and is in a status other than Unknown (the default status every charm is born with), you will have to pass the 'initial status' in State. +If you want to simulate a situation in which the charm already has seen some event, and is in a status other than Unknown (the default status every charm is born with), you will have to pass the 'initial status' to State. ```python from ops.model import ActiveStatus @@ -211,13 +213,63 @@ def test_relation_data(): # which is very idiomatic and superbly explicit. Noice. ``` -## Relation types + +The only mandatory argument to `Relation` (and other relation types, see below) is `endpoint`. The `interface` will be derived from the charm's `metadata.yaml`. When fully defaulted, a relation is 'empty'. There are no remote units, the remote application is called `'remote'` and only has a single unit `remote/0`, and nobody has written any data to the databags yet. + +That is typically the state of a relation when the first unit joins it. When you use `Relation`, you are specifying a regular (conventional) relation. But that is not the only type of relation. There are also peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app', because it's the same application. +### PeerRelation +To declare a peer relation, you should use `scenario.state.PeerRelation`. +The core difference with regular relations is that peer relations do not have a "remote app" (it's this app, in fact). +So unlike `Relation`, a `PeerRelation` does not have `remote_app_name` or `remote_app_data` arguments. Also, it talks in terms of `peers`: +- `Relation.remote_unit_ids` maps to `PeerRelation.peers_ids` +- `Relation.remote_units_data` maps to `PeerRelation.peers_data` + +```python +from scenario.state import PeerRelation + +relation = PeerRelation( + endpoint="peers", + peers_data={1: {}, 2: {}, 42: {'foo': 'bar'}}, +) +``` + +be mindful when using `PeerRelation` not to include **"this unit"**'s ID in `peers_data` or `peers_ids`, as that would be flagged by the Consistency Checker: +```python +from scenario import State, PeerRelation + +State(relations=[ + PeerRelation( + endpoint="peers", + peers_data={1: {}, 2: {}, 42: {'foo': 'bar'}}, + )]).trigger("start", ..., unit_id=1) # invalid: this unit's id cannot be the ID of a peer. + + +``` + +### SubordinateRelation +To declare a subordinate relation, you should use `scenario.state.SubordinateRelation`. +The core difference with regular relations is that subordinate relations always have exactly one remote unit (there is always exactly one primary unit that this unit can see). +So unlike `Relation`, a `SubordinateRelation` does not have a `remote_units_data` argument. Instead, it has a `remote_unit_data` taking a single `Dict[str:str]`, and takes the primary unit ID as a separate argument. +Also, it talks in terms of `primary`: +- `Relation.remote_unit_ids` becomes `SubordinateRelation.primary_id` (a single ID instead of a list of IDs) +- `Relation.remote_units_data` becomes `SubordinateRelation.remote_unit_data` (a single databag instead of a mapping from unit IDs to databags) +- `Relation.remote_app_name` maps to `SubordinateRelation.primary_app_name` -TODO: describe peer/sub API. +```python +from scenario.state import SubordinateRelation + +relation = SubordinateRelation( + endpoint="peers", + remote_unit_data={"foo": "bar"}, + primary_app_name="zookeeper", + primary_id=42 +) +relation.primary_name # "zookeeper/42" +``` ## Triggering Relation Events diff --git a/scenario/runtime.py b/scenario/runtime.py index 647035df..680f945d 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -145,6 +145,7 @@ def __init__( charm_spec: "_CharmSpec", charm_root: Optional["PathLike"] = None, juju_version: str = "3.0.0", + unit_id: int = 0, ): self._charm_spec = charm_spec self._juju_version = juju_version @@ -155,8 +156,8 @@ def __init__( raise ValueError('invalid metadata: mandatory "name" field is missing.') self._app_name = app_name - # todo: consider parametrizing unit-id? cfr https://github.com/canonical/ops-scenario/issues/11 - self._unit_name = f"{app_name}/0" + self._unit_id = unit_id + self._unit_name = f"{app_name}/{unit_id}" @staticmethod def _cleanup_env(env): @@ -412,6 +413,7 @@ def trigger( config: Optional[Dict[str, Any]] = None, charm_root: Optional[Dict["PathLike", "PathLike"]] = None, juju_version: str = "3.0", + unit_id: int = 0, ) -> "State": """Trigger a charm execution with an Event and a State. @@ -433,6 +435,7 @@ def trigger( :arg config: charm config to use. Needs to be a valid config.yaml format (as a python dict). If none is provided, we will search for a ``config.yaml`` file in the charm root. :arg juju_version: Juju agent version to simulate. + :arg unit_id: The ID of the Juju unit that is charm execution is running on. :arg charm_root: virtual charm root the charm will be executed with. If the charm, say, expects a `./src/foo/bar.yaml` file present relative to the execution cwd, you need to use this. E.g.: @@ -464,6 +467,7 @@ def trigger( charm_spec=spec, juju_version=juju_version, charm_root=charm_root, + unit_id=unit_id, ) return runtime.exec( diff --git a/scenario/sequences.py b/scenario/sequences.py index 04044126..fa30b4dc 100644 --- a/scenario/sequences.py +++ b/scenario/sequences.py @@ -96,6 +96,7 @@ def check_builtin_sequences( template_state: State = None, pre_event: Optional[Callable[["CharmType"], None]] = None, post_event: Optional[Callable[["CharmType"], None]] = None, + unit_id: int = 0, ): """Test that all the builtin startup and teardown events can fire without errors. @@ -124,4 +125,5 @@ def check_builtin_sequences( config=config, pre_event=pre_event, post_event=post_event, + unit_id=unit_id, ) diff --git a/scenario/state.py b/scenario/state.py index 4ec81857..9b453d8d 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -846,6 +846,7 @@ def trigger( config: Optional[Dict[str, Any]] = None, charm_root: Optional["PathLike"] = None, juju_version: str = "3.0", + unit_id: int = 0, ) -> "State": """Fluent API for trigger. See runtime.trigger's docstring.""" from scenario.runtime import trigger as _runtime_trigger @@ -861,6 +862,7 @@ def trigger( config=config, charm_root=charm_root, juju_version=juju_version, + unit_id=unit_id, ) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index d6ef9be1..954ec934 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -88,7 +88,8 @@ class MyEvt(EventBase): @pytest.mark.parametrize("app_name", ("foo", "bar-baz", "QuX2")) -def test_unit_name(app_name): +@pytest.mark.parametrize("unit_id", (1, 2, 42)) +def test_unit_name(app_name, unit_id): meta = { "name": app_name, } @@ -100,9 +101,10 @@ def test_unit_name(app_name): my_charm_type, meta=meta, ), + unit_id=unit_id, ) def post_event(charm: CharmBase): - assert charm.unit.name == f"{app_name}/0" + assert charm.unit.name == f"{app_name}/{unit_id}" runtime.exec(state=State(), event=Event("start"), post_event=post_event)