diff --git a/tdp/core/variables/cluster_variables.py b/tdp/core/variables/cluster_variables.py index 1379b896..fe016221 100644 --- a/tdp/core/variables/cluster_variables.py +++ b/tdp/core/variables/cluster_variables.py @@ -116,8 +116,9 @@ def initialize_cluster_variables( if service in services_initialized_by_this_function: try: - service_variables.update_from_variables_folder( - "add variables from " + collection_name, path + service_variables.update_from_dir( + path, + validation_message="add variables from " + collection_name, ) except EmptyCommit: logger.warning( diff --git a/tdp/core/variables/service_variables.py b/tdp/core/variables/service_variables.py index bfc21393..519e125c 100644 --- a/tdp/core/variables/service_variables.py +++ b/tdp/core/variables/service_variables.py @@ -5,7 +5,7 @@ import logging from collections import OrderedDict -from collections.abc import Generator +from collections.abc import Generator, Iterable from contextlib import ExitStack, contextmanager from pathlib import Path from typing import TYPE_CHECKING @@ -14,11 +14,15 @@ from tdp.core.operation import SERVICE_NAME_MAX_LENGTH from tdp.core.types import PathLike from tdp.core.variables.schema import validate_against_schema -from tdp.core.variables.variables import Variables, VariablesDict +from tdp.core.variables.variables import ( + Variables, + VariablesDict, +) if TYPE_CHECKING: from tdp.core.repository.repository import Repository from tdp.core.service_component_name import ServiceComponentName + from tdp.core.variables.variables import _VariablesIOWrapper logger = logging.getLogger(__name__) @@ -89,78 +93,72 @@ def get_variables(self, component_name: str) -> dict: with Variables(component_path).open("r") as variables: return variables.copy() - def update_from_variables_folder( - self, message: str, tdp_vars_overrides: PathLike + def update_from_dir( + self, input_dir: PathLike, /, *, validation_message: str ) -> None: - """Update the variables repository from an overrides file. + """Update the service variables from an input directory. - Args: - message: Validation message. - tdp_vars_overrides: Overrides file path. - """ - override_files = list(Path(tdp_vars_overrides).glob("*" + YML_EXTENSION)) - service_files_to_open = [override_file.name for override_file in override_files] - with self.open_var_files(f"{message}", service_files_to_open) as configurations: - for file in override_files: - logger.info(f"Updating {self.name} with variables from {file}") - with Variables(file).open("r") as variables: - configurations[file.name].merge(variables) + Input variables are merged with the existing ones. If a variable file is not + present in the repository, it is created. If a variable file is present in the + repository but not in the input directory, it is not modified. - @contextmanager - def _open_var_file( - self, path: PathLike, fail_if_does_not_exist: bool = False - ) -> Variables: - """Context manager to facilitate the opening a variables file. - - Provides a Variables object automatically closed when parent context manager closes it. + Changes are persisted to the `tdp_vars` service repository using the given + validation message. Args: - path: Path of the variables file to open. - fail_if_does_not_exist: Whether or not the function should raise an error when file does not exist. - - Yields: - A weakref of the Variables object, to prevent the creation of strong references - outside the caller's context. - - Raises: - ValueError: If the file does not exist and fail_if_does_not_exist is True. + message: Validation message to use for the repository. + input_dir: Path to the directory containing the variables files to import. """ - path = self.path / path - path.parent.mkdir(parents=True, exist_ok=True) - if not path.exists(): - if fail_if_does_not_exist: - raise ValueError("Path does not exist") - path.touch() - with Variables(path).open() as variables: - yield variables + input_file_paths = Path(input_dir).glob("*" + YML_EXTENSION) + # Open corresponding files in the repository. + files_to_open = (input_file_path.name for input_file_path in input_file_paths) + with self.open_files( + files_to_open, validation_message=validation_message, create_if_missing=True + ) as files: + # Merge the input files into the repository files. + for input_file_path in input_file_paths: + with Variables(input_file_path).open("r") as input_file: + files[input_file_path.name].merge(input_file) @contextmanager - def open_var_files( - self, message: str, paths: list[str], fail_if_does_not_exist: bool = False - ) -> Generator[OrderedDict[str, Variables], None, None]: - """Open variables files. - - Adds the underlying files for validation. + def open_files( + self, + file_names: Iterable[str], + /, + *, + validation_message: str, + create_if_missing: bool = False, + ) -> Generator[OrderedDict[str, _VariablesIOWrapper], None, None]: + """Open files in the service repository. + + Allow to open multiple files in the service repository at once in a context + manager. Files can be modified in the context manager. Changes are persisted to + the `tdp_vars` service repository using the given validation message. Args: - message: Validation message. - paths: List of paths to open. + validation_message: Validation message to use for the repository. + file_names: Names of the files to manage. + create_if_missing: Whether to create the file if it does not exist. Yields: - Variables as an OrderedDict where keys are sorted by the order of the input paths. + A dictionary of opened files. """ - with self.repository.validate(message), ExitStack() as stack: - yield OrderedDict( + with self.repository.validate(validation_message) as repo, ExitStack() as stack: + open_files = OrderedDict( ( - path, + file_name, + # Stack is used to properly close the files when exiting the + # context manager. stack.enter_context( - self._open_var_file(path, fail_if_does_not_exist) + Variables( + self.path / file_name, create_if_missing=create_if_missing + ).open() ), ) - for path in paths + for file_name in file_names ) - stack.close() - self.repository.add_for_validation(paths) + yield open_files + repo.add_for_validation(file_names) def is_sc_modified_from_version( self, service_component: ServiceComponentName, version: str diff --git a/tdp/core/variables/variables.py b/tdp/core/variables/variables.py index 716cd9d2..4133aea2 100644 --- a/tdp/core/variables/variables.py +++ b/tdp/core/variables/variables.py @@ -26,13 +26,20 @@ class Variables: del variables["key1"] # deletes value at key `key1` """ - def __init__(self, file_path: PathLike): + def __init__(self, file_path: PathLike, /, *, create_if_missing: bool = False): """Initializes a Variables instance. Args: file_path: Path to the file. + create_if_missing: Whether to create the file if it does not exist. """ self._file_path = Path(file_path) + # Create the file if it does not exist + if not self._file_path.exists(): + if not create_if_missing: + raise FileNotFoundError(f"'{file_path}' does not exist.") + self._file_path.parent.mkdir(parents=True, exist_ok=True) + self._file_path.touch() def open(self, mode: Optional[str] = None) -> "_VariablesIOWrapper": """Opens the file in the given mode. @@ -114,7 +121,7 @@ def __init__(self, path: Path, mode: Optional[str] = None): self._content = from_yaml(self._file_descriptor) or {} self._name = path.name - def __enter__(self): + def __enter__(self) -> "_VariablesIOWrapper": return proxy(self) def __exit__(self, exc_type, exc_val, exc_tb):