diff --git a/README.md b/README.md index 62dbcf0f..c618c9d6 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ def _on_event(self, _event): You can verify that the charm has followed the expected path by checking the **unit status history** like so: ```python +from charm import MyCharm from ops.model import MaintenanceStatus, ActiveStatus, WaitingStatus, UnknownStatus from scenario import State @@ -148,6 +149,7 @@ def test_statuses(): UnknownStatus(), MaintenanceStatus('determining who the ruler is...'), WaitingStatus('checking this is right...'), + ActiveStatus("I am ruled"), ] ``` @@ -155,7 +157,7 @@ Note that the current status is not in the **unit status history**. Also note that, unless you initialize the State with a preexisting status, the first status in the history will always be `unknown`. That is because, so far as scenario is concerned, each event is "the first event this charm has ever seen". -If you want to simulate a situation in which the charm already has seen some event, and is in a status other than Unknown (the default status every charm is born with), you will have to pass the 'initial status' in State. +If you want to simulate a situation in which the charm already has seen some event, and is in a status other than Unknown (the default status every charm is born with), you will have to pass the 'initial status' to State. ```python from ops.model import ActiveStatus @@ -212,6 +214,102 @@ def test_relation_data(): # which is very idiomatic and superbly explicit. Noice. ``` +The only mandatory argument to `Relation` (and other relation types, see below) is `endpoint`. The `interface` will be derived from the charm's `metadata.yaml`. When fully defaulted, a relation is 'empty'. There are no remote units, the remote application is called `'remote'` and only has a single unit `remote/0`, and nobody has written any data to the databags yet. + +That is typically the state of a relation when the first unit joins it. + +When you use `Relation`, you are specifying a regular (conventional) relation. But that is not the only type of relation. There are also +peer relations and subordinate relations. While in the background the data model is the same, the data access rules and the consistency constraints on them are very different. For example, it does not make sense for a peer relation to have a different 'remote app' than its 'local app', because it's the same application. + +### PeerRelation +To declare a peer relation, you should use `scenario.state.PeerRelation`. +The core difference with regular relations is that peer relations do not have a "remote app" (it's this app, in fact). +So unlike `Relation`, a `PeerRelation` does not have `remote_app_name` or `remote_app_data` arguments. Also, it talks in terms of `peers`: +- `Relation.remote_unit_ids` maps to `PeerRelation.peers_ids` +- `Relation.remote_units_data` maps to `PeerRelation.peers_data` + +```python +from scenario.state import PeerRelation + +relation = PeerRelation( + endpoint="peers", + peers_data={1: {}, 2: {}, 42: {'foo': 'bar'}}, +) +``` + +be mindful when using `PeerRelation` not to include **"this unit"**'s ID in `peers_data` or `peers_ids`, as that would be flagged by the Consistency Checker: +```python +from scenario import State, PeerRelation + +State(relations=[ + PeerRelation( + endpoint="peers", + peers_data={1: {}, 2: {}, 42: {'foo': 'bar'}}, + )]).trigger("start", ..., unit_id=1) # invalid: this unit's id cannot be the ID of a peer. + + +``` + +### SubordinateRelation +To declare a subordinate relation, you should use `scenario.state.SubordinateRelation`. +The core difference with regular relations is that subordinate relations always have exactly one remote unit (there is always exactly one primary unit that this unit can see). +So unlike `Relation`, a `SubordinateRelation` does not have a `remote_units_data` argument. Instead, it has a `remote_unit_data` taking a single `Dict[str:str]`, and takes the primary unit ID as a separate argument. +Also, it talks in terms of `primary`: +- `Relation.remote_unit_ids` becomes `SubordinateRelation.primary_id` (a single ID instead of a list of IDs) +- `Relation.remote_units_data` becomes `SubordinateRelation.remote_unit_data` (a single databag instead of a mapping from unit IDs to databags) +- `Relation.remote_app_name` maps to `SubordinateRelation.primary_app_name` + +```python +from scenario.state import SubordinateRelation + +relation = SubordinateRelation( + endpoint="peers", + remote_unit_data={"foo": "bar"}, + primary_app_name="zookeeper", + primary_id=42 +) +relation.primary_name # "zookeeper/42" +``` + + +## Triggering Relation Events +If you want to trigger relation events, the easiest way to do so is get a hold of the Relation instance and grab the event from one of its aptly-named properties: + +```python +from scenario import Relation +relation = Relation(endpoint="foo", interface="bar") +changed_event = relation.changed_event +joined_event = relation.joined_event +# ... +``` + +This is in fact syntactic sugar for: +```python +from scenario import Relation, Event +relation = Relation(endpoint="foo", interface="bar") +changed_event = Event('foo-relation-changed', relation=relation) +``` + +The reason for this construction is that the event is associated with some relation-specific metadata, that Scenario needs to set up the process that will run `ops.main` with the right environment variables. + +### Additional event parameters +All relation events have some additional metadata that does not belong in the Relation object, such as, for a relation-joined event, the name of the (remote) unit that is joining the relation. That is what determines what `ops.model.Unit` you get when you get `RelationJoinedEvent().unit` in an event handler. + +In order to supply this parameter, you will have to **call** the event object and pass as `remote_unit_id` the id of the remote unit that the event is about. +The reason that this parameter is not supplied to `Relation` but to relation events, is that the relation already ties 'this app' to some 'remote app' (cfr. the `Relation.remote_app_name` attr), but not to a specific unit. What remote unit this event is about is not a `State` concern but an `Event` one. + +The `remote_unit_id` will default to the first ID found in the relation's `remote_unit_ids`, but if the test you are writing is close to that domain, you should probably override it and pass it manually. + +```python +from scenario import Relation, Event +relation = Relation(endpoint="foo", interface="bar") +remote_unit_2_is_joining_event = relation.joined_event(remote_unit_id=2) + +# which is syntactic sugar for: +remote_unit_2_is_joining_event = Event('foo-relation-changed', relation=relation, relation_remote_unit_id=2) +``` + + ## Containers When testing a kubernetes charm, you can mock container interactions. diff --git a/pyproject.toml b/pyproject.toml index f93c900e..f8936873 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "ops-scenario" -version = "2.1.3.3" +version = "2.1.3.4" authors = [ { name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" } ] diff --git a/scenario/consistency_checker.py b/scenario/consistency_checker.py index 0c1b6759..98ee93ad 100644 --- a/scenario/consistency_checker.py +++ b/scenario/consistency_checker.py @@ -1,9 +1,11 @@ import os +from collections import Counter +from itertools import chain from typing import TYPE_CHECKING, Iterable, NamedTuple, Tuple from scenario.runtime import InconsistentScenarioError from scenario.runtime import logger as scenario_logger -from scenario.state import _CharmSpec, normalize_name +from scenario.state import PeerRelation, SubordinateRelation, _CharmSpec, normalize_name if TYPE_CHECKING: from scenario.state import Event, State @@ -51,6 +53,7 @@ def check_consistency( check_config_consistency, check_event_consistency, check_secrets_consistency, + check_relation_consistency, ): results = check( state=state, event=event, charm_spec=charm_spec, juju_version=juju_version @@ -179,6 +182,61 @@ def check_secrets_consistency( return Results(errors, []) +def check_relation_consistency( + *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs +) -> Results: + errors = [] + nonpeer_relations_meta = chain( + charm_spec.meta.get("requires", {}).items(), + charm_spec.meta.get("provides", {}).items(), + ) + peer_relations_meta = charm_spec.meta.get("peers", {}).items() + all_relations_meta = list(chain(nonpeer_relations_meta, peer_relations_meta)) + + def _get_relations(r): + try: + return state.get_relations(r) + except ValueError: + return () + + # check relation types + for endpoint, _ in peer_relations_meta: + for relation in _get_relations(endpoint): + if not isinstance(relation, PeerRelation): + errors.append( + f"endpoint {endpoint} is a peer relation; " + f"expecting relation to be of type PeerRelation, got {type(relation)}" + ) + + for endpoint, relation_meta in all_relations_meta: + expected_sub = relation_meta.get("scope", "") == "container" + relations = _get_relations(endpoint) + for relation in relations: + is_sub = isinstance(relation, SubordinateRelation) + if is_sub and not expected_sub: + errors.append( + f"endpoint {endpoint} is not a subordinate relation; " + f"expecting relation to be of type Relation, " + f"got {type(relation)}" + ) + if expected_sub and not is_sub: + errors.append( + f"endpoint {endpoint} is not a subordinate relation; " + f"expecting relation to be of type SubordinateRelation, " + f"got {type(relation)}" + ) + + # check for duplicate endpoint names + seen_endpoints = set() + for endpoint, relation_meta in all_relations_meta: + if endpoint in seen_endpoints: + errors.append("duplicate endpoint name in metadata.") + break + seen_endpoints.add(endpoint) + + return Results(errors, []) + + def check_containers_consistency( *, state: "State", event: "Event", charm_spec: "_CharmSpec", **_kwargs ) -> Results: @@ -209,4 +267,10 @@ def check_containers_consistency( f"some containers declared in the state are not specified in metadata. That's not possible. " f"Missing from metadata: {diff}." ) + + # guard against duplicate container names + names = Counter(state_containers) + if dupes := [n for n in names if names[n] > 1]: + errors.append(f"Duplicate container name(s): {dupes}.") + return Results(errors, []) diff --git a/scenario/fs_mocks.py b/scenario/fs_mocks.py new file mode 100644 index 00000000..38548aec --- /dev/null +++ b/scenario/fs_mocks.py @@ -0,0 +1,35 @@ +import pathlib +from typing import Dict + +from ops.testing import _TestingFilesystem, _TestingStorageMount # noqa + + +# todo consider duplicating the filesystem on State.copy() to be able to diff and have true state snapshots +class _MockStorageMount(_TestingStorageMount): + def __init__(self, location: pathlib.PurePosixPath, src: pathlib.Path): + """Creates a new simulated storage mount. + + Args: + location: The path within simulated filesystem at which this storage will be mounted. + src: The temporary on-disk location where the simulated storage will live. + """ + self._src = src + self._location = location + + try: + # for some reason this fails if src exists, even though exists_ok=True. + super().__init__(location=location, src=src) + except FileExistsError: + pass + + +class _MockFileSystem(_TestingFilesystem): + def __init__(self, mounts: Dict[str, _MockStorageMount]): + super().__init__() + self._mounts = mounts + + def add_mount(self, *args, **kwargs): + raise NotImplementedError("Cannot mutate mounts; declare them all in State.") + + def remove_mount(self, *args, **kwargs): + raise NotImplementedError("Cannot mutate mounts; declare them all in State.") diff --git a/scenario/mocking.py b/scenario/mocking.py index 80069804..8bfda1c4 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -2,7 +2,6 @@ # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. import datetime -import pathlib import random from io import StringIO from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union @@ -10,13 +9,22 @@ from ops import pebble from ops.model import SecretInfo, SecretRotate, _ModelBackend from ops.pebble import Client, ExecError -from ops.testing import _TestingFilesystem, _TestingPebbleClient, _TestingStorageMount +from ops.testing import _TestingPebbleClient from scenario.logger import logger as scenario_logger +from scenario.state import PeerRelation if TYPE_CHECKING: from scenario.state import Container as ContainerSpec - from scenario.state import Event, ExecOutput, State, _CharmSpec + from scenario.state import ( + Event, + ExecOutput, + PeerRelation, + Relation, + State, + SubordinateRelation, + _CharmSpec, + ) logger = scenario_logger.getChild("mocking") @@ -62,7 +70,9 @@ def get_pebble(self, socket_path: str) -> "Client": charm_spec=self._charm_spec, ) - def _get_relation_by_id(self, rel_id): + def _get_relation_by_id( + self, rel_id + ) -> Union["Relation", "SubordinateRelation", "PeerRelation"]: try: return next( filter(lambda r: r.relation_id == rel_id, self._state.relations) @@ -101,9 +111,9 @@ def relation_get(self, rel_id, obj_name, app): return relation.remote_app_data elif obj_name == self.unit_name: return relation.local_unit_data - else: - unit_id = obj_name.split("/")[-1] - return relation.remote_units_data[int(unit_id)] + + unit_id = int(obj_name.split("/")[-1]) + return relation._get_databag_for_remote(unit_id) # noqa def is_leader(self): return self._state.leader @@ -119,11 +129,14 @@ def relation_ids(self, relation_name): if rel.endpoint == relation_name ] - def relation_list(self, relation_id: int): + def relation_list(self, relation_id: int) -> Tuple[str]: relation = self._get_relation_by_id(relation_id) + + if isinstance(relation, PeerRelation): + return tuple(f"{self.app_name}/{unit_id}" for unit_id in relation.peers_ids) return tuple( - f"{relation.remote_app_name}/{unit_id}" - for unit_id in relation.remote_unit_ids + f"{relation._remote_app_name}/{unit_id}" # noqa + for unit_id in relation._remote_unit_ids # noqa ) def config_get(self): @@ -317,35 +330,6 @@ def planned_units(self, *args, **kwargs): raise NotImplementedError("planned_units") -class _MockStorageMount(_TestingStorageMount): - def __init__(self, location: pathlib.PurePosixPath, src: pathlib.Path): - """Creates a new simulated storage mount. - - Args: - location: The path within simulated filesystem at which this storage will be mounted. - src: The temporary on-disk location where the simulated storage will live. - """ - self._src = src - self._location = location - if ( - not src.exists() - ): # we need to add this guard because the directory might exist already. - src.mkdir(exist_ok=True, parents=True) - - -# todo consider duplicating the filesystem on State.copy() to be able to diff and have true state snapshots -class _MockFileSystem(_TestingFilesystem): - def __init__(self, mounts: Dict[str, _MockStorageMount]): - super().__init__() - self._mounts = mounts - - def add_mount(self, *args, **kwargs): - raise NotImplementedError("Cannot mutate mounts; declare them all in State.") - - def remove_mount(self, *args, **kwargs): - raise NotImplementedError("Cannot mutate mounts; declare them all in State.") - - class _MockPebbleClient(_TestingPebbleClient): def __init__( self, diff --git a/scenario/ops_main_mock.py b/scenario/ops_main_mock.py index e7723ee3..85b34a78 100644 --- a/scenario/ops_main_mock.py +++ b/scenario/ops_main_mock.py @@ -14,7 +14,6 @@ from ops.main import CHARM_STATE_FILE, _Dispatcher, _emit_charm_event, _get_charm_dir from scenario.logger import logger as scenario_logger -from scenario.mocking import _MockModelBackend if TYPE_CHECKING: from ops.testing import CharmType @@ -38,6 +37,9 @@ def main( """Set up the charm and dispatch the observed event.""" charm_class = charm_spec.charm_type charm_dir = _get_charm_dir() + + from scenario.mocking import _MockModelBackend + model_backend = _MockModelBackend( # pyright: reportPrivateUsage=false state=state, event=event, charm_spec=charm_spec ) diff --git a/scenario/runtime.py b/scenario/runtime.py index 35757f3b..680f945d 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -25,11 +25,19 @@ from scenario.logger import logger as scenario_logger from scenario.ops_main_mock import NoObserverError +from scenario.state import DeferredEvent, PeerRelation, StoredState if TYPE_CHECKING: from ops.testing import CharmType - from scenario.state import DeferredEvent, Event, State, StoredState, _CharmSpec + from scenario.state import ( + AnyRelation, + DeferredEvent, + Event, + State, + StoredState, + _CharmSpec, + ) _CT = TypeVar("_CT", bound=Type[CharmType]) @@ -73,7 +81,6 @@ def _open_db(self) -> Optional[SQLiteStorage]: def get_stored_state(self) -> List["StoredState"]: """Load any StoredState data structures from the db.""" - from scenario.state import StoredState # avoid cyclic import db = self._open_db() @@ -92,7 +99,6 @@ def get_stored_state(self) -> List["StoredState"]: def get_deferred_events(self) -> List["DeferredEvent"]: """Load any DeferredEvent data structures from the db.""" - from scenario.state import DeferredEvent # avoid cyclic import db = self._open_db() @@ -139,6 +145,7 @@ def __init__( charm_spec: "_CharmSpec", charm_root: Optional["PathLike"] = None, juju_version: str = "3.0.0", + unit_id: int = 0, ): self._charm_spec = charm_spec self._juju_version = juju_version @@ -148,8 +155,9 @@ def __init__( if not app_name: raise ValueError('invalid metadata: mandatory "name" field is missing.') - # todo: consider parametrizing unit-id? cfr https://github.com/canonical/ops-scenario/issues/11 - self._unit_name = f"{app_name}/0" + self._app_name = app_name + self._unit_id = unit_id + self._unit_name = f"{app_name}/{unit_id}" @staticmethod def _cleanup_env(env): @@ -157,7 +165,8 @@ def _cleanup_env(env): # running this in a clean venv or a container anyway. # cleanup env, in case we'll be firing multiple events, we don't want to accumulate. for key in env: - os.unsetenv(key) + # os.unsetenv does not work !? + del os.environ[key] def _get_event_env(self, state: "State", event: "Event", charm_root: Path): if event.name.endswith("_action"): @@ -178,14 +187,48 @@ def _get_event_env(self, state: "State", event: "Event", charm_root: Path): # todo consider setting pwd, (python)path } - if relation := event.relation: + relation: "AnyRelation" + + if event._is_relation_event and (relation := event.relation): # noqa + if isinstance(relation, PeerRelation): + remote_app_name = self._app_name + else: + remote_app_name = relation._remote_app_name # noqa env.update( { "JUJU_RELATION": relation.endpoint, "JUJU_RELATION_ID": str(relation.relation_id), + "JUJU_REMOTE_APP": remote_app_name, } ) + remote_unit_id = event.relation_remote_unit_id + if ( + remote_unit_id is None + ): # don't check truthiness because it could be int(0) + remote_unit_ids = relation._remote_unit_ids # noqa + + if len(remote_unit_ids) == 1: + remote_unit_id = remote_unit_ids[0] + logger.info( + "there's only one remote unit, so we set JUJU_REMOTE_UNIT to it, " + "but you probably should be parametrizing the event with `remote_unit_id` " + "to be explicit." + ) + else: + remote_unit_id = remote_unit_ids[0] + logger.warning( + "remote unit ID unset, and multiple remote unit IDs are present; " + "We will pick the first one and hope for the best. You should be passing " + "`remote_unit_id` to the Event constructor." + ) + + if remote_unit_id is not None: + remote_unit = f"{remote_app_name}/{remote_unit_id}" + env["JUJU_REMOTE_UNIT"] = remote_unit + if event.name.endswith("_relation_departed"): + env["JUJU_DEPARTING_UNIT"] = remote_unit + if container := event.container: env.update({"JUJU_WORKLOAD_NAME": container.name}) @@ -348,7 +391,7 @@ def exec( finally: logger.info(" - Exited ops.main.") - logger.info(" - clearing env") + logger.info(" - Clearing env") self._cleanup_env(env) logger.info(" - closing storage") @@ -370,6 +413,7 @@ def trigger( config: Optional[Dict[str, Any]] = None, charm_root: Optional[Dict["PathLike", "PathLike"]] = None, juju_version: str = "3.0", + unit_id: int = 0, ) -> "State": """Trigger a charm execution with an Event and a State. @@ -391,6 +435,7 @@ def trigger( :arg config: charm config to use. Needs to be a valid config.yaml format (as a python dict). If none is provided, we will search for a ``config.yaml`` file in the charm root. :arg juju_version: Juju agent version to simulate. + :arg unit_id: The ID of the Juju unit that is charm execution is running on. :arg charm_root: virtual charm root the charm will be executed with. If the charm, say, expects a `./src/foo/bar.yaml` file present relative to the execution cwd, you need to use this. E.g.: @@ -422,6 +467,7 @@ def trigger( charm_spec=spec, juju_version=juju_version, charm_root=charm_root, + unit_id=unit_id, ) return runtime.exec( diff --git a/scenario/sequences.py b/scenario/sequences.py index 04044126..fa30b4dc 100644 --- a/scenario/sequences.py +++ b/scenario/sequences.py @@ -96,6 +96,7 @@ def check_builtin_sequences( template_state: State = None, pre_event: Optional[Callable[["CharmType"], None]] = None, post_event: Optional[Callable[["CharmType"], None]] = None, + unit_id: int = 0, ): """Test that all the builtin startup and teardown events can fire without errors. @@ -124,4 +125,5 @@ def check_builtin_sequences( config=config, pre_event=pre_event, post_event=post_event, + unit_id=unit_id, ) diff --git a/scenario/state.py b/scenario/state.py index 50160cb1..9b453d8d 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -17,9 +17,8 @@ from ops.charm import CharmEvents from ops.model import SecretRotate, StatusBase +from scenario.fs_mocks import _MockFileSystem, _MockStorageMount from scenario.logger import logger as scenario_logger -from scenario.mocking import _MockFileSystem, _MockStorageMount -from scenario.runtime import trigger as _runtime_trigger if typing.TYPE_CHECKING: try: @@ -29,6 +28,7 @@ from ops.testing import CharmType PathLike = Union[str, Path] + AnyRelation = Union["Relation", "PeerRelation", "SubordinateRelation"] logger = scenario_logger.getChild("state") @@ -63,6 +63,13 @@ } +class StateValidationError(RuntimeError): + """Raised when individual parts of the State are inconsistent.""" + + # as opposed to InconsistentScenario error where the + # **combination** of several parts of the State are. + + @dataclasses.dataclass class _DCBase: def replace(self, *args, **kwargs): @@ -144,19 +151,30 @@ def normalize_name(s: str): return s.replace("-", "_") -@dataclasses.dataclass -class Relation(_DCBase): - endpoint: str - remote_app_name: str = "remote" - remote_unit_ids: List[int] = dataclasses.field(default_factory=list) +class ParametrizedEvent: + def __init__(self, accept_params: Tuple[str], category: str, *args, **kwargs): + self._accept_params = accept_params + self._category = category + self._args = args + self._kwargs = kwargs - # local limit - limit: int = 1 + def __call__(self, remote_unit: Optional[str] = None) -> "Event": + """Construct an Event object using the arguments provided at init and any extra params.""" + if remote_unit and "remote_unit" not in self._accept_params: + raise ValueError( + f"cannot pass param `remote_unit` to a " + f"{self._category} event constructor." + ) - # scale of the remote application; number of units, leader ID? - # TODO figure out if this is relevant - scale: int = 1 - leader_id: int = 0 + return Event(*self._args, *self._kwargs, relation_remote_unit_id=remote_unit) + + def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": + return self().deferred(handler=handler, event_id=event_id) + + +@dataclasses.dataclass +class RelationBase(_DCBase): + endpoint: str # we can derive this from the charm's metadata interface: str = None @@ -165,67 +183,251 @@ class Relation(_DCBase): relation_id: int = -1 local_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) - remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) local_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) - remote_units_data: Dict[int, Dict[str, str]] = dataclasses.field( - default_factory=dict - ) + + @property + def _databags(self): + """Yield all databags in this relation.""" + yield self.local_app_data + yield self.local_unit_data + + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + raise NotImplementedError() + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + raise NotImplementedError() + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + raise NotImplementedError() def __post_init__(self): + if type(self) is RelationBase: + raise RuntimeError( + "RelationBase cannot be instantiated directly; " + "please use Relation, PeerRelation, or SubordinateRelation" + ) + global _RELATION_IDS_CTR if self.relation_id == -1: _RELATION_IDS_CTR += 1 + logger.info( + f"relation ID unset; automatically assigning {_RELATION_IDS_CTR}. " + f"If there are problems, pass one manually." + ) self.relation_id = _RELATION_IDS_CTR - if self.remote_unit_ids and self.remote_units_data: - if not set(self.remote_unit_ids) == set(self.remote_units_data): - raise ValueError( - f"{self.remote_unit_ids} should include any and all IDs from {self.remote_units_data}" + for databag in self._databags: + self._validate_databag(databag) + + def _validate_databag(self, databag: dict): + if not isinstance(databag, dict): + raise StateValidationError( + f"all databags should be dicts, not {type(databag)}" + ) + for k, v in databag.items(): + if not isinstance(v, str): + raise StateValidationError( + f"all databags should be Dict[str,str]; " + f"found a value of type {type(v)}" ) - elif self.remote_unit_ids: - self.remote_units_data = {x: {} for x in self.remote_unit_ids} - elif self.remote_units_data: - self.remote_unit_ids = [x for x in self.remote_units_data] - else: - self.remote_unit_ids = [0] - self.remote_units_data = {0: {}} @property - def changed_event(self): + def changed_event(self) -> "Event": """Sugar to generate a -relation-changed event.""" return Event( name=normalize_name(self.endpoint + "-relation-changed"), relation=self ) @property - def joined_event(self): + def joined_event(self) -> "Event": """Sugar to generate a -relation-joined event.""" return Event( name=normalize_name(self.endpoint + "-relation-joined"), relation=self ) @property - def created_event(self): + def created_event(self) -> "Event": """Sugar to generate a -relation-created event.""" return Event( name=normalize_name(self.endpoint + "-relation-created"), relation=self ) @property - def departed_event(self): + def departed_event(self) -> "Event": """Sugar to generate a -relation-departed event.""" return Event( name=normalize_name(self.endpoint + "-relation-departed"), relation=self ) @property - def broken_event(self): + def broken_event(self) -> "Event": """Sugar to generate a -relation-broken event.""" return Event( name=normalize_name(self.endpoint + "-relation-broken"), relation=self ) +def unify_ids_and_remote_units_data(ids: List[int], data: Dict[int, Any]): + """Unify and validate a list of unit IDs and a mapping from said ids to databag contents. + + This allows the user to pass equivalently: + ids = [] + data = {1: {}} + + or + + ids = [1] + data = {} + + or + + ids = [1] + data = {1: {}} + + but catch the inconsistent: + + ids = [1] + data = {2: {}} + + or + + ids = [2] + data = {1: {}} + """ + if ids and data: + if not set(ids) == set(data): + raise StateValidationError( + f"{ids} should include any and all IDs from {data}" + ) + elif ids: + data = {x: {} for x in ids} + elif data: + ids = [x for x in data] + else: + ids = [0] + data = {0: {}} + return ids, data + + +@dataclasses.dataclass +class Relation(RelationBase): + remote_app_name: str = "remote" + remote_unit_ids: List[int] = dataclasses.field(default_factory=list) + + # local limit + limit: int = 1 + + remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) + remote_units_data: Dict[int, Dict[str, str]] = dataclasses.field( + default_factory=dict + ) + + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + return self.remote_app_name + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return tuple(self.remote_unit_ids) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.remote_units_data[unit_id] + + @property + def _databags(self): + """Yield all databags in this relation.""" + yield self.local_app_data + yield self.local_unit_data + yield self.remote_app_data + yield from self.remote_units_data.values() + + def __post_init__(self): + super().__post_init__() + self.remote_unit_ids, self.remote_units_data = unify_ids_and_remote_units_data( + self.remote_unit_ids, self.remote_units_data + ) + + +@dataclasses.dataclass +class SubordinateRelation(RelationBase): + # todo: consider renaming them to primary_*_data + remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) + remote_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) + + # app name and ID of the primary that *this unit* is attached to. + primary_app_name: str = "remote" + primary_id: int = 0 + + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + return self.primary_app_name + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return (self.primary_id,) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.remote_unit_data + + @property + def _databags(self): + """Yield all databags in this relation.""" + yield self.local_app_data + yield self.local_unit_data + yield self.remote_app_data + yield self.remote_unit_data + + @property + def primary_name(self) -> str: + return f"{self.primary_app_name}/{self.primary_id}" + + +@dataclasses.dataclass +class PeerRelation(RelationBase): + peers_data: Dict[int, Dict[str, str]] = dataclasses.field(default_factory=dict) + + # IDs of the peers. Consistency checks will validate that *this unit*'s ID is not in here. + peers_ids: List[int] = dataclasses.field(default_factory=list) + + @property + def _databags(self): + """Yield all databags in this relation.""" + yield self.local_app_data + yield self.local_unit_data + yield from self.peers_data.values() + + @property + def _remote_app_name(self) -> str: + """Who is on the other end of this relation?""" + # surprise! It's myself. + raise ValueError("peer relations don't quite have a remote end.") + + @property + def _remote_unit_ids(self) -> Tuple[int]: + """Ids of the units on the other end of this relation.""" + return tuple(self.peers_ids) + + def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: + """Return the databag for some remote unit ID.""" + return self.peers_data[unit_id] + + def __post_init__(self): + self.peers_ids, self.peers_data = unify_ids_and_remote_units_data( + self.peers_ids, self.peers_data + ) + + def _random_model_name(): import random import string @@ -361,7 +563,7 @@ def services(self) -> Dict[str, pebble.ServiceInfo]: return infos @property - def filesystem(self) -> _MockFileSystem: + def filesystem(self) -> "_MockFileSystem": mounts = { name: _MockStorageMount( src=Path(spec.src), location=PurePosixPath(spec.location) @@ -562,7 +764,7 @@ class State(_DCBase): config: Dict[str, Union[str, int, float, bool]] = dataclasses.field( default_factory=dict ) - relations: List[Relation] = dataclasses.field(default_factory=list) + relations: List["AnyRelation"] = dataclasses.field(default_factory=list) networks: List[Network] = dataclasses.field(default_factory=list) containers: List[Container] = dataclasses.field(default_factory=list) status: Status = dataclasses.field(default_factory=Status) @@ -608,7 +810,14 @@ def get_container(self, container: Union[str, Container]) -> Container: except StopIteration as e: raise ValueError(f"container: {name}") from e - # FIXME: not a great way to obtain a delta, but is "complete" todo figure out a better way. + def get_relations(self, endpoint: str) -> Tuple["AnyRelation"]: + """Get relation from this State, based on an input relation or its endpoint name.""" + try: + return tuple(filter(lambda c: c.endpoint == endpoint, self.relations)) + except StopIteration as e: + raise ValueError(f"relation: {endpoint}") from e + + # FIXME: not a great way to obtain a delta, but is "complete". todo figure out a better way. def jsonpatch_delta(self, other: "State"): try: import jsonpatch @@ -637,8 +846,11 @@ def trigger( config: Optional[Dict[str, Any]] = None, charm_root: Optional["PathLike"] = None, juju_version: str = "3.0", + unit_id: int = 0, ) -> "State": """Fluent API for trigger. See runtime.trigger's docstring.""" + from scenario.runtime import trigger as _runtime_trigger + return _runtime_trigger( state=self, event=event, @@ -650,10 +862,9 @@ def trigger( config=config, charm_root=charm_root, juju_version=juju_version, + unit_id=unit_id, ) - trigger.__doc__ = _runtime_trigger.__doc__ - @dataclasses.dataclass class _CharmSpec(_DCBase): @@ -720,7 +931,9 @@ class Event(_DCBase): kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # if this is a relation event, the relation it refers to - relation: Optional[Relation] = None + relation: Optional["AnyRelation"] = None + # and the name of the remote unit this relation event is about + relation_remote_unit_id: Optional[int] = None # if this is a secret event, the secret it refers to secret: Optional[Secret] = None @@ -733,6 +946,14 @@ class Event(_DCBase): # - pebble? # - action? + def __call__(self, remote_unit_id: Optional[int] = None) -> "Event": + if remote_unit_id and not self._is_relation_event: + raise ValueError( + "cannot pass param `remote_unit_id` to a " + "non-relation event constructor." + ) + return self.replace(relation_remote_unit_id=remote_unit_id) + def __post_init__(self): if "-" in self.name: logger.warning(f"Only use underscores in event names. {self.name!r}") @@ -834,9 +1055,9 @@ def deferred(self, handler: Callable, event_id: int = 1) -> DeferredEvent: # this is a RelationEvent. The snapshot: snapshot_data = { "relation_name": self.relation.endpoint, - "relation_id": self.relation.relation_id - # 'app_name': local app name - # 'unit_name': local unit name + "relation_id": self.relation.relation_id, + "app_name": self.relation.remote_app_name, + "unit_name": f"{self.relation.remote_app_name}/{self.relation_remote_unit_id}", } return DeferredEvent( diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 6e82119a..3a8511f9 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -7,9 +7,11 @@ RELATION_EVENTS_SUFFIX, Container, Event, + PeerRelation, Relation, Secret, State, + SubordinateRelation, _CharmSpec, ) @@ -154,3 +156,49 @@ def test_secrets_jujuv_bad(good_v): _CharmSpec(MyCharm, {}), good_v, ) + + +def test_peer_relation_consistency(): + assert_inconsistent( + State(relations=[Relation("foo")]), + Event("bar"), + _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), + ) + assert_consistent( + State(relations=[PeerRelation("foo")]), + Event("bar"), + _CharmSpec(MyCharm, {"peers": {"foo": {"interface": "bar"}}}), + ) + + +def test_sub_relation_consistency(): + assert_inconsistent( + State(relations=[Relation("foo")]), + Event("bar"), + _CharmSpec( + MyCharm, {"requires": {"foo": {"interface": "bar", "scope": "container"}}} + ), + ) + assert_consistent( + State(relations=[SubordinateRelation("foo")]), + Event("bar"), + _CharmSpec( + MyCharm, {"requires": {"foo": {"interface": "bar", "scope": "container"}}} + ), + ) + + +def test_relation_sub_inconsistent(): + assert_inconsistent( + State(relations=[SubordinateRelation("foo")]), + Event("bar"), + _CharmSpec(MyCharm, {"requires": {"foo": {"interface": "bar"}}}), + ) + + +def test_dupe_containers_inconsistent(): + assert_inconsistent( + State(containers=[Container("foo"), Container("foo")]), + Event("bar"), + _CharmSpec(MyCharm, {"containers": {"foo": {}}}), + ) diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 58dd3205..7d429aa9 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -1,10 +1,17 @@ from typing import Type import pytest -from ops.charm import CharmBase, CharmEvents +from ops.charm import CharmBase, CharmEvents, RelationDepartedEvent from ops.framework import EventBase, Framework -from scenario.state import Relation, State +from scenario.state import ( + PeerRelation, + Relation, + RelationBase, + State, + StateValidationError, + SubordinateRelation, +) @pytest.fixture(scope="function") @@ -124,3 +131,162 @@ def callback(charm: CharmBase, _): }, }, ) + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "remote_app_name", + ("remote", "prometheus", "aodeok123"), +) +@pytest.mark.parametrize( + "remote_unit_id", + (0, 1), +) +def test_relation_events_attrs(mycharm, evt_name, remote_app_name, remote_unit_id): + relation = Relation( + endpoint="foo", interface="foo", remote_app_name=remote_app_name + ) + + def callback(charm: CharmBase, event): + assert event.app + assert event.unit + if isinstance(event, RelationDepartedEvent): + assert event.departing_unit + + mycharm._call = callback + + State( + relations=[ + relation, + ], + ).trigger( + getattr(relation, f"{evt_name}_event")(remote_unit_id=remote_unit_id), + mycharm, + meta={ + "name": "local", + "requires": { + "foo": {"interface": "foo"}, + }, + }, + ) + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "remote_app_name", + ("remote", "prometheus", "aodeok123"), +) +def test_relation_events_no_attrs(mycharm, evt_name, remote_app_name, caplog): + relation = Relation( + endpoint="foo", + interface="foo", + remote_app_name=remote_app_name, + remote_units_data={0: {}, 1: {}}, # 2 units + ) + + def callback(charm: CharmBase, event): + assert event.app # that's always present + assert event.unit + assert (evt_name == "departed") is bool(getattr(event, "departing_unit", False)) + + mycharm._call = callback + + State( + relations=[ + relation, + ], + ).trigger( + getattr(relation, f"{evt_name}_event"), + mycharm, + meta={ + "name": "local", + "requires": { + "foo": {"interface": "foo"}, + }, + }, + ) + + assert ( + "remote unit ID unset, and multiple remote unit IDs are present" in caplog.text + ) + + +@pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) +def test_relation_unit_data_bad_types(mycharm, data): + with pytest.raises(StateValidationError): + relation = Relation( + endpoint="foo", interface="foo", remote_units_data={0: {"a": data}} + ) + + +@pytest.mark.parametrize("data", (set(), {}, [], (), 1, 1.0, None, b"")) +def test_relation_app_data_bad_types(mycharm, data): + with pytest.raises(StateValidationError): + relation = Relation(endpoint="foo", interface="foo", local_app_data={"a": data}) + + +@pytest.mark.parametrize( + "evt_name", + ("changed", "broken", "departed", "joined", "created"), +) +@pytest.mark.parametrize( + "relation", + (Relation("a"), PeerRelation("b"), SubordinateRelation("c")), +) +def test_relation_event_trigger(relation, evt_name, mycharm): + meta = { + "name": "mycharm", + "requires": {"a": {"interface": "i1"}}, + "provides": { + "c": { + "interface": "i3", + # this is a subordinate relation. + "scope": "container", + } + }, + "peers": {"b": {"interface": "i2"}}, + } + state = State(relations=[relation]).trigger( + getattr(relation, evt_name + "_event"), mycharm, meta=meta + ) + + +def test_trigger_sub_relation(mycharm): + meta = { + "name": "mycharm", + "provides": { + "foo": { + "interface": "bar", + # this is a subordinate relation. + "scope": "container", + } + }, + } + + sub1 = SubordinateRelation( + "foo", remote_unit_data={"1": "2"}, primary_app_name="primary1" + ) + sub2 = SubordinateRelation( + "foo", remote_unit_data={"3": "4"}, primary_app_name="primary2" + ) + + def post_event(charm: CharmBase): + b_relations = charm.model.relations["foo"] + assert len(b_relations) == 2 + for relation in b_relations: + assert len(relation.units) == 1 + + State(relations=[sub1, sub2]).trigger( + "update-status", mycharm, meta=meta, post_event=post_event + ) + + +def test_cannot_instantiate_relationbase(): + with pytest.raises(RuntimeError): + RelationBase("") diff --git a/tests/test_runtime.py b/tests/test_runtime.py index d6ef9be1..954ec934 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -88,7 +88,8 @@ class MyEvt(EventBase): @pytest.mark.parametrize("app_name", ("foo", "bar-baz", "QuX2")) -def test_unit_name(app_name): +@pytest.mark.parametrize("unit_id", (1, 2, 42)) +def test_unit_name(app_name, unit_id): meta = { "name": app_name, } @@ -100,9 +101,10 @@ def test_unit_name(app_name): my_charm_type, meta=meta, ), + unit_id=unit_id, ) def post_event(charm: CharmBase): - assert charm.unit.name == f"{app_name}/0" + assert charm.unit.name == f"{app_name}/{unit_id}" runtime.exec(state=State(), event=Event("start"), post_event=post_event) diff --git a/tox.ini b/tox.ini index 8689e828..a25b8f0f 100644 --- a/tox.ini +++ b/tox.ini @@ -28,6 +28,7 @@ commands = [testenv:lint] +skip_install=True description = lint deps = coverage[toml] @@ -41,6 +42,7 @@ commands = [testenv:fmt] +skip_install=True description = Format code deps = black