Skip to content

Commit

Permalink
Merge pull request #28 from galaxyproject/add_local_context
Browse files Browse the repository at this point in the history
Add support for defining local contexts
  • Loading branch information
nuwang authored Jun 15, 2022
2 parents f7d6e40 + 0c0f4b5 commit b2a53ee
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 16 deletions.
61 changes: 61 additions & 0 deletions tests/fixtures/mapping-context.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
global:
default_inherits: default
context:
small_jpb_cores: 1
medium_job_cores: 2
large_job_cores: 4
small_input_size: 2
medium_input_size: 10
large_input_size: 20

tools:
default:
context:
medium_job_cores: 3
medium_input_size: 12
cores: medium_job_cores
mem: cores * 3
gpus: 1
env:
TEST_JOB_SLOTS: "{cores}"
params:
native_spec: "--mem {mem} --cores {cores} --gpus {gpus}"
scheduling:
require: []
prefer:
- general
accept:
reject:
- pulsar
rules:
- if: input_size < small_input_size
fail: We don't run piddling datasets
bwa:
context:
medium_job_cores: 5
gpus: 2
scheduling:
require:
- pulsar
rules:
- if: input_size <= medium_input_size
gpus: 4
- if: input_size >= large_input_size
fail: Too much data, shouldn't run
trinity:
gpus: 3

destinations:
local:
cores: 4
mem: 16
scheduling:
prefer:
- general
k8s_environment:
cores: 16
mem: 64
gpus: 5
scheduling:
prefer:
- pulsar
39 changes: 39 additions & 0 deletions tests/test_mapper_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import unittest
from tpv.rules import gateway
from . import mock_galaxy
from tpv.core.loader import InvalidParentException


class TestMapperContext(unittest.TestCase):

@staticmethod
def _map_to_destination(tool, user, datasets, tpv_config_path=None):
galaxy_app = mock_galaxy.App()
job = mock_galaxy.Job()
for d in datasets:
job.add_input_dataset(d)
tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__),
'fixtures/mapping-context.yml')
gateway.ACTIVE_DESTINATION_MAPPER = None
return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])

def test_map_context_default_overrides_global(self):
tool = mock_galaxy.Tool('trinity')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.id, "local")
self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['3'])
self.assertEqual(destination.params['native_spec'], '--mem 9 --cores 3 --gpus 3')

def test_map_tool_overrides_default(self):
tool = mock_galaxy.Tool('bwa')
user = mock_galaxy.User('gargravarr', '[email protected]')
datasets = [mock_galaxy.DatasetAssociation("test", mock_galaxy.Dataset("test.txt", file_size=5*1024**3))]

destination = self._map_to_destination(tool, user, datasets)
self.assertEqual(destination.id, "k8s_environment")
self.assertEqual([env['value'] for env in destination.env if env['name'] == 'TEST_JOB_SLOTS'], ['5'])
self.assertEqual(destination.params['native_spec'], '--mem 15 --cores 5 --gpus 4')
2 changes: 1 addition & 1 deletion tests/test_mapper_inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _map_to_destination(tool, user, datasets, tpv_config_path=None):
for d in datasets:
job.add_input_dataset(d)
tpv_config = tpv_config_path or os.path.join(os.path.dirname(__file__),
'fixtures/mapping-inheritance.yml')
'fixtures/mapping-inheritance.yml')
gateway.ACTIVE_DESTINATION_MAPPER = None
return gateway.map_tool_to_destination(galaxy_app, job, tool, user, tpv_config_files=[tpv_config])

Expand Down
42 changes: 27 additions & 15 deletions tpv/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def from_dict(tags: list[dict]) -> TagSetManager:
class Entity(object):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, params=None, resubmit=None,
tags=None, rank=None, inherits=None):
tags=None, rank=None, inherits=None, context=None):
self.loader = loader
self.id = id
self.cores = cores
Expand All @@ -175,6 +175,7 @@ def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, p
self.tags = TagSetManager.from_dict(tags or {})
self.rank = rank
self.inherits = inherits
self.context = context
self.validate()

def process_complex_property(self, prop, context, func):
Expand Down Expand Up @@ -223,7 +224,7 @@ def validate(self):
def __repr__(self):
return f"{self.__class__} id={self.id}, cores={self.cores}, mem={self.mem}, gpus={self.gpus}, " \
f"env={self.env}, params={self.params}, resubmit={self.resubmit}, tags={self.tags}, " \
f"rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}"
f"rank={self.rank[:10] if self.rank else ''}, inherits={self.inherits}, context={self.context}"

def override(self, entity):
new_entity = copy.copy(entity)
Expand All @@ -239,6 +240,8 @@ def override(self, entity):
new_entity.resubmit.update(self.resubmit or {})
new_entity.rank = self.rank if self.rank is not None else entity.rank
new_entity.inherits = self.inherits if self.inherits is not None else entity.inherits
new_entity.context = copy.copy(entity.context) or {}
new_entity.context.update(self.context or {})
return new_entity

def inherit(self, entity):
Expand Down Expand Up @@ -308,6 +311,7 @@ def evaluate_early(self, context):
:return:
"""
new_entity = copy.deepcopy(self)
context.update(self.context or {})
if self.gpus:
new_entity.gpus = self.loader.eval_code_block(self.gpus, context)
context['gpus'] = new_entity.gpus
Expand All @@ -327,6 +331,7 @@ def evaluate_late(self, context):
:return:
"""
new_entity = copy.deepcopy(self)
context.update(self.context or {})
context['gpus'] = new_entity.gpus
context['cores'] = new_entity.cores
context['mem'] = new_entity.mem
Expand Down Expand Up @@ -360,8 +365,9 @@ def score(self, entity):
class EntityWithRules(Entity):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None,
params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None):
super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits)
params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None):
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, rank=rank, inherits=inherits, context=context)
self.rules = self.validate_rules(rules)

def validate_rules(self, rules: list) -> list:
Expand Down Expand Up @@ -389,6 +395,7 @@ def from_dict(cls: type, loader, entity_dict):
tags=entity_dict.get('scheduling'),
rank=entity_dict.get('rank'),
inherits=entity_dict.get('inherits'),
context=entity_dict.get('context'),
rules=entity_dict.get('rules')
)

Expand Down Expand Up @@ -432,31 +439,33 @@ def __repr__(self):
class Tool(EntityWithRules):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None,
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None):
super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits, rules)
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None):
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, rank=rank, inherits=inherits, context=context, rules=rules)


class User(EntityWithRules):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None,
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None):
super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, rank, inherits, rules)
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None):
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, rank=rank, inherits=inherits, context=context, rules=rules)


class Role(EntityWithRules):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None,
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, rules=None):
env=None, params=None, resubmit=None, tags=None, rank=None, inherits=None, context=None, rules=None):
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, rank=rank, inherits=inherits, rules=rules)
tags=tags, rank=rank, inherits=inherits, context=context, rules=rules)


class Destination(EntityWithRules):

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None,
env=None, params=None, resubmit=None, tags=None, inherits=None, rules=None):
env=None, params=None, resubmit=None, tags=None, inherits=None, context=None, rules=None):
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, inherits=inherits, rules=rules)
tags=tags, inherits=inherits, context=context, rules=rules)

@staticmethod
def from_dict(loader, entity_dict):
Expand All @@ -471,6 +480,7 @@ def from_dict(loader, entity_dict):
resubmit=entity_dict.get('resubmit'),
tags=entity_dict.get('scheduling'),
inherits=entity_dict.get('inherits'),
context=entity_dict.get('context'),
rules=entity_dict.get('rules')
)

Expand All @@ -479,12 +489,13 @@ class Rule(Entity):

rule_counter = 0

def __init__(self, loader, id=None, cores=None, mem=None, gpus=None,
env=None, params=None, resubmit=None, tags=None, inherits=None, match=None, execute=None, fail=None):
def __init__(self, loader, id=None, cores=None, mem=None, gpus=None, env=None, params=None, resubmit=None,
tags=None, inherits=None, context=None, match=None, execute=None, fail=None):
if not id:
Rule.rule_counter += 1
id = f"tpv_rule_{Rule.rule_counter}"
super().__init__(loader, id, cores, mem, gpus, env, params, resubmit, tags, inherits=inherits)
super().__init__(loader, id=id, cores=cores, mem=mem, gpus=gpus, env=env, params=params, resubmit=resubmit,
tags=tags, context=context, inherits=inherits)
self.match = match
self.execute = execute
self.fail = fail
Expand All @@ -508,6 +519,7 @@ def from_dict(loader, entity_dict):
resubmit=entity_dict.get('resubmit'),
tags=entity_dict.get('scheduling'),
inherits=entity_dict.get('inherits'),
context=entity_dict.get('context'),
# TODO: Remove deprecated match clause in future
match=entity_dict.get('if') or entity_dict.get('match'),
execute=entity_dict.get('execute'),
Expand Down

0 comments on commit b2a53ee

Please sign in to comment.