From 0c0f4b54139a9ee0a8dc1611a2b30027f8c689d8 Mon Sep 17 00:00:00 2001 From: Nuwan Goonasekera <2070605+nuwang@users.noreply.github.com> Date: Wed, 15 Jun 2022 00:23:22 +0530 Subject: [PATCH] Add support for defining local contexts --- tests/fixtures/mapping-context.yml | 61 ++++++++++++++++++++++++++++++ tests/test_mapper_context.py | 39 +++++++++++++++++++ tests/test_mapper_inheritance.py | 2 +- tpv/core/entities.py | 42 ++++++++++++-------- 4 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 tests/fixtures/mapping-context.yml create mode 100644 tests/test_mapper_context.py diff --git a/tests/fixtures/mapping-context.yml b/tests/fixtures/mapping-context.yml new file mode 100644 index 0000000..0ed641a --- /dev/null +++ b/tests/fixtures/mapping-context.yml @@ -0,0 +1,61 @@ +global: + default_inherits: default + context: + small_jpb_cores: 1 + medium_job_cores: 2 + large_job_cores: 4 + small_input_size: 2 + medium_input_size: 10 + large_input_size: 20 + +tools: + default: + context: + medium_job_cores: 3 + medium_input_size: 12 + cores: medium_job_cores + mem: cores * 3 + gpus: 1 + env: + TEST_JOB_SLOTS: "{cores}" + params: + native_spec: "--mem {mem} --cores {cores} --gpus {gpus}" + scheduling: + require: [] + prefer: + - general + accept: + reject: + - pulsar + rules: + - if: input_size < small_input_size + fail: We don't run piddling datasets + bwa: + context: + medium_job_cores: 5 + gpus: 2 + scheduling: + require: + - pulsar + rules: + - if: input_size <= medium_input_size + gpus: 4 + - if: input_size >= large_input_size + fail: Too much data, shouldn't run + trinity: + gpus: 3 + +destinations: + local: + cores: 4 + mem: 16 + scheduling: + prefer: + - general + k8s_environment: + cores: 16 + mem: 64 + gpus: 5 + scheduling: + prefer: + - pulsar diff --git a/tests/test_mapper_context.py b/tests/test_mapper_context.py new file mode 100644 index 0000000..a44c333 --- /dev/null +++ b/tests/test_mapper_context.py @@ -0,0 +1,39 @@ +import os +import unittest +from tpv.rules import gateway +from . import mock_galaxy +from tpv.core.loader import InvalidParentException + + +class TestMapperContext(unittest.TestCase): + + @staticmethod + def _map_to_destination(tool, user, datasets, tpv_config_path=None): + galaxy_app = mock_galaxy.App() + job = mock_galaxy.Job() + for d in datasets: + job.add_input_dataset(d) + tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), + 'fixtures/mapping-context.yml') + gateway.ACTIVE_DESTINATION_MAPPER = None + return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) + + def test_map_context_default_overrides_global(self): + tool = mock_galaxy.Tool('trinity') + user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + + destination = self._map_to_destination(tool, user, datasets) + self.assertEqual(destination.id, "local") + self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['3']) + self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 3') + + def test_map_tool_overrides_default(self): + tool = mock_galaxy.Tool('bwa') + user = mock_galaxy.User('gargravarr', 'fairycake@vortex.org') + datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))] + + destination = self._map_to_destination(tool, user, datasets) + self.assertEqual(destination.id, "k8s_environment") + self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['5']) + self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 4') diff --git a/tests/test_mapper_inheritance.py b/tests/test_mapper_inheritance.py index 5593944..629e67d 100644 --- a/tests/test_mapper_inheritance.py +++ b/tests/test_mapper_inheritance.py @@ -14,7 +14,7 @@ def _map_to_destination(tool, user, datasets, tpv_config_path=None): for d in datasets: job.add_input_dataset(d) tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__), - 'fixtures/mapping-inheritance.yml') + 'fixtures/mapping-inheritance.yml') gateway.ACTIVE_DESTINATION_MAPPER = None return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config]) diff --git a/tpv/core/entities.py b/tpv/core/entities.py index 78c9660..6296848 100644 --- a/tpv/core/entities.py +++ b/tpv/core/entities.py @@ -163,7 +163,7 @@ def from_dict(tags: list[dict]) -> TagSetManager: class Entity(object): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, params=None, resubmit=None, - tags=None, rank=None, inherits=None): + tags=None, rank=None, inherits=None, context=None): self.loader = loader self.id = id self.cores = cores @@ -175,6 +175,7 @@ def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, p self.tags = TagSetManager.from_dict(tags or {}) self.rank = rank self.inherits = inherits + self.context = context self.validate() def process_complex_property(self, prop, context, func): @@ -223,7 +224,7 @@ def validate(self): def __repr__(self): return f"{self.__class__} id={self.id}, cores={self.cores}, mem={self.mem}, gpus={self.gpus}, " \ f"env={self.env}, params={self.params}, resubmit={self.resubmit}, tags={self.tags}, " \ - f"rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}" + f"rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, context={self.context}" def override(self, entity): new_entity = copy.copy(entity) @@ -239,6 +240,8 @@ def override(self, entity): new_entity.resubmit.update(self.resubmit or {}) new_entity.rank = self.rank if self.rank is not None else entity.rank new_entity.inherits = self.inherits if self.inherits is not None else entity.inherits + new_entity.context = copy.copy(entity.context) or {} + new_entity.context.update(self.context or {}) return new_entity def inherit(self, entity): @@ -308,6 +311,7 @@ def evaluate_early(self, context): :return: """ new_entity = copy.deepcopy(self) + context.update(self.context or {}) if self.gpus: new_entity.gpus = self.loader.eval_code_block(self.gpus, context) context['gpus'] = new_entity.gpus @@ -327,6 +331,7 @@ def evaluate_late(self, context): :return: """ new_entity = copy.deepcopy(self) + context.update(self.context or {}) context['gpus'] = new_entity.gpus context['cores'] = new_entity.cores context['mem'] = new_entity.mem @@ -360,8 +365,9 @@ def score(self, entity): class EntityWithRules(Entity): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, - params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None): - super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits) + params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None): + super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, + tags=tags, rank=rank, inherits=inherits, context=context) self.rules = self.validate_rules(rules) def validate_rules(self, rules: list) -> list: @@ -389,6 +395,7 @@ def from_dict(cls: type, loader, entity_dict): tags=entity_dict.get('scheduling'), rank=entity_dict.get('rank'), inherits=entity_dict.get('inherits'), + context=entity_dict.get('context'), rules=entity_dict.get('rules') ) @@ -432,31 +439,33 @@ def __repr__(self): class Tool(EntityWithRules): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, - env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None): - super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits, rules) + env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None): + super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, + tags=tags, rank=rank, inherits=inherits, context=context, rules=rules) class User(EntityWithRules): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, - env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None): - super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits, rules) + env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None): + super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, + tags=tags, rank=rank, inherits=inherits, context=context, rules=rules) class Role(EntityWithRules): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, - env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None): + env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None): super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, - tags=tags, rank=rank, inherits=inherits, rules=rules) + tags=tags, rank=rank, inherits=inherits, context=context, rules=rules) class Destination(EntityWithRules): def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, - env=None, params=None, resubmit=None, tags=None, inherits=None, rules=None): + env=None, params=None, resubmit=None, tags=None, inherits=None, context=None, rules=None): super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, - tags=tags, inherits=inherits, rules=rules) + tags=tags, inherits=inherits, context=context, rules=rules) @staticmethod def from_dict(loader, entity_dict): @@ -471,6 +480,7 @@ def from_dict(loader, entity_dict): resubmit=entity_dict.get('resubmit'), tags=entity_dict.get('scheduling'), inherits=entity_dict.get('inherits'), + context=entity_dict.get('context'), rules=entity_dict.get('rules') ) @@ -479,12 +489,13 @@ class Rule(Entity): rule_counter = 0 - def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, - env=None, params=None, resubmit=None, tags=None, inherits=None, match=None, execute=None, fail=None): + def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, params=None, resubmit=None, + tags=None, inherits=None, context=None, match=None, execute=None, fail=None): if not id: Rule.rule_counter += 1 id = f"tpv_rule_{Rule.rule_counter}" - super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, inherits=inherits) + super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit, + tags=tags, context=context, inherits=inherits) self.match = match self.execute = execute self.fail = fail @@ -508,6 +519,7 @@ def from_dict(loader, entity_dict): resubmit=entity_dict.get('resubmit'), tags=entity_dict.get('scheduling'), inherits=entity_dict.get('inherits'), + context=entity_dict.get('context'), # TODO: Remove deprecated match clause in future match=entity_dict.get('if') or entity_dict.get('match'), execute=entity_dict.get('execute'),