From cd80edbc716c0379660c64c5fddf415186f8054f Mon Sep 17 00:00:00 2001 From: Matthew Dailis Date: Sat, 13 Jul 2024 13:19:46 -0700 Subject: [PATCH] Improve ergonomics of call and spawn --- pymerlin/_internal/_context.py | 8 +++-- pymerlin/_internal/_decorators.py | 41 ++++++++++++++++++----- pymerlin/_internal/_directive_type.py | 11 ++---- pymerlin/_internal/_globals.py | 2 +- pymerlin/_internal/_model_type.py | 5 +-- pymerlin/_internal/_spawn_helpers.py | 16 ++++----- pymerlin/_internal/_task_specification.py | 34 +++++++------------ pymerlin/_internal/_threaded_task.py | 13 ++++--- pymerlin/model_actions.py | 30 +++++++---------- tests/test_simulation.py | 6 ++-- 10 files changed, 89 insertions(+), 77 deletions(-) diff --git a/pymerlin/_internal/_context.py b/pymerlin/_internal/_context.py index b1a380f..5262ee6 100644 --- a/pymerlin/_internal/_context.py +++ b/pymerlin/_internal/_context.py @@ -5,16 +5,17 @@ @contextmanager -def _context(scheduler, spawner=None, reaction_context=None): - _set_context(scheduler, spawner, reaction_context) +def _context(scheduler, spawner=None, reaction_context=None, model_type=None): + _set_context(scheduler, spawner, reaction_context, model_type=None) yield _clear_context() -def _set_context(context, spawner, reaction_context): +def _set_context(context, spawner, reaction_context, model_type): _current_context.clear() _current_context.append(context) _current_context.append(spawner) + _current_context.append(model_type) _globals.reaction_context = reaction_context @@ -22,4 +23,5 @@ def _clear_context(): _current_context.clear() _current_context.append(None) _current_context.append(None) + _current_context.append(None) _globals.reaction_context = None diff --git a/pymerlin/_internal/_decorators.py b/pymerlin/_internal/_decorators.py index 0591d98..e391851 100644 --- a/pymerlin/_internal/_decorators.py +++ b/pymerlin/_internal/_decorators.py @@ -6,6 +6,10 @@ import warnings from dataclasses import dataclass +from pymerlin._internal._serialized_value import from_map_str_serialized_value +from pymerlin._internal._spawn_helpers import activity_wrapper, get_topics +from pymerlin._internal._task_specification import TaskInstance + def MissionModel(cls): """ @@ -19,7 +23,12 @@ def MissionModel(cls): cls.activity_types = {} def ActivityType(func): - activity_definition = wrap(func) + if type(func) == TaskDefinition: + activity_definition = func + elif callable(func): + activity_definition = TaskDefinition(func.__name__, lambda *args, **kwargs: activity_wrapper(TaskDefinition("inner", func), args, kwargs, *get_topics(activity_definition))) + else: + raise ValueError("Cannot decorate " + repr(func) + " with @ActivityType") if activity_definition.name in cls.activity_types: warnings.warn("Re-defining activity type: " + activity_definition.name) cls.activity_types[activity_definition.name] = activity_definition @@ -29,7 +38,7 @@ def ActivityType(func): def Task(func): - return TaskDefinition(func) + return TaskDefinition(func.__name__, func) def Validation(validator, message=None): @@ -52,16 +61,32 @@ class ValidationResult: class TaskDefinition: - def __init__(self, inner): - self.inner = inner - self.name = inner.__name__ + """ + TaskDefinition can produce a TaskInstance given all of the arguments for that task + """ + def __init__(self, name, func): + self.name = name + self.inner = func self.validations = [] def add_validation(self, validation): self.validations.insert(0, validation) - def run_task_definition(self, *args, **kwargs): - return self.inner.__call__(*args, **kwargs) + def __call__(self, *args, **kwargs): + return self.make_instance(*args, **kwargs) + + def make_instance(self, *args, **kwargs) -> TaskInstance: + # inspect.getfullargspec(self.inner) + # return self.inner.__call__(*args, **kwargs) + return TaskInstance(lambda: self.inner.__call__(*args, **kwargs)) + # , f"{self.name}({', '.join(f'{k}={v}' for k, v in kwargs.items())})" + + def get_task_factory(self, model, args, gateway, model_type): + from pymerlin._internal._task_factory import TaskFactory + from pymerlin._internal._threaded_task import ThreadedTaskHost + + # It is expected that the first argument to an activity be the mission model + return TaskFactory(lambda: ThreadedTaskHost(gateway, model_type, self.make_instance(model, **from_map_str_serialized_value(gateway, args)))) def wrap(x): @@ -71,5 +96,5 @@ def wrap(x): if type(x) == TaskDefinition: return x if callable(x): - return TaskDefinition(x) + return TaskDefinition(x.__name__, x) raise Exception("Unhandled variant: " + str(type(x))) \ No newline at end of file diff --git a/pymerlin/_internal/_directive_type.py b/pymerlin/_internal/_directive_type.py index d2fd064..749affc 100644 --- a/pymerlin/_internal/_directive_type.py +++ b/pymerlin/_internal/_directive_type.py @@ -1,20 +1,16 @@ from pymerlin._internal._decorators import TaskDefinition from pymerlin._internal._globals import models_by_id from pymerlin._internal._input_type import InputType -from pymerlin._internal._serialized_value import from_map_str_serialized_value -from pymerlin._internal._spawn_helpers import activity_wrapper -from pymerlin._internal._task_factory import TaskFactory -from pymerlin._internal._threaded_task import ThreadedTaskHost - class DirectiveType: - def __init__(self, gateway, activity, input_topic, output_topic): + def __init__(self, gateway, activity, input_topic, output_topic, model_type): if type(activity) is not TaskDefinition: raise ValueError("Activity must be of type TaskDefinition, but was: " + repr(activity)) self.gateway = gateway self.activity = activity self.input_topic = input_topic self.output_topic = output_topic + self.model_type = model_type def getInputType(self): return InputType() @@ -23,8 +19,7 @@ def getOutputType(self): return None def getTaskFactory(self, model_id, args): - task_provider = TaskDefinition(lambda: activity_wrapper(self.activity, from_map_str_serialized_value(self.gateway, args), models_by_id[model_id][0], self.input_topic, self.output_topic)) - return TaskFactory(lambda: ThreadedTaskHost(self.gateway, models_by_id[model_id][1], task_provider)) + return self.activity.get_task_factory(models_by_id[model_id][0], args, self.gateway, self.model_type) class Java: implements = ["gov.nasa.jpl.aerie.merlin.protocol.model.DirectiveType"] diff --git a/pymerlin/_internal/_globals.py b/pymerlin/_internal/_globals.py index dd4898d..253a33f 100644 --- a/pymerlin/_internal/_globals.py +++ b/pymerlin/_internal/_globals.py @@ -1,6 +1,6 @@ models_by_id = {} -_current_context = [None, None] +_current_context = [None, None, None] next_cell_id = 0 diff --git a/pymerlin/_internal/_model_type.py b/pymerlin/_internal/_model_type.py index 7e55b44..2e5edaa 100644 --- a/pymerlin/_internal/_model_type.py +++ b/pymerlin/_internal/_model_type.py @@ -33,7 +33,7 @@ def spawn(coro): new_task = ThreadedTaskHost(self.gateway, self, coro) builder.daemon(TaskFactory(lambda: new_task)) - with _context(None, spawner=spawn): + with _context(None, spawner=spawn, model_type=self): model = self.model_class(registrar) model._model_type = self @@ -66,7 +66,8 @@ def getDirectiveTypes(self): self.gateway, activity_type[0], # TaskDefinition activity_type[1], # input_topic - activity_type[2]) # output_topic + activity_type[2], # output_topic + self) # model type for activity_type in self.activity_types }, self.gateway._gateway_client) diff --git a/pymerlin/_internal/_spawn_helpers.py b/pymerlin/_internal/_spawn_helpers.py index 5b55d22..ebb764f 100644 --- a/pymerlin/_internal/_spawn_helpers.py +++ b/pymerlin/_internal/_spawn_helpers.py @@ -1,5 +1,4 @@ from pymerlin._internal import _globals -from pymerlin._internal._decorators import TaskDefinition # async def activity_wrapper(task, args, model, input_topic, output_topic): @@ -12,18 +11,19 @@ # if output_topic is not None: # _globals._current_context[0].emit({}, output_topic) -def activity_wrapper(task, args, model, input_topic, output_topic): +def activity_wrapper(task, args, kwargs, input_topic, output_topic): + from pymerlin._internal._decorators import TaskDefinition if type(task) is not TaskDefinition: raise ValueError("Hmm, why? " + repr(task)) - if input_topic is not None: - _globals._current_context[0].emit(args, input_topic) - task.run_task_definition(model, **args) - if output_topic is not None: - _globals._current_context[0].emit({}, output_topic) + _globals._current_context[0].emit({}, input_topic) + task.make_instance(*args, **kwargs).run() + _globals._current_context[0].emit({}, output_topic) -def get_topics(model_type, func): +def get_topics(func): + from pymerlin._internal._decorators import TaskDefinition if type(func) is not TaskDefinition: raise Exception("Whoa there buddy") + model_type = _globals._current_context[2] for activity_func, input_topic, output_topic in model_type.activity_types: if activity_func is func: return input_topic, output_topic diff --git a/pymerlin/_internal/_task_specification.py b/pymerlin/_internal/_task_specification.py index d149751..9871c72 100644 --- a/pymerlin/_internal/_task_specification.py +++ b/pymerlin/_internal/_task_specification.py @@ -1,26 +1,18 @@ class TaskInstance: - def __init__(self, func, kwargs, model, validations, definition): + """ + A TaskInstance is just a lambda with extra steps + """ + def __init__(self, func): self.func = func - self.args = kwargs - self.model = model - self.validations = validations # - self.definition = definition - # self.kwargs = kwargs - def instantiate(self): - if self.model is None: - return self.func(**self.args) #, **self.kwargs) - else: - return self.func(self.model, **self.args) # , **self.kwargs) + # def validate(self): + # return [ + # validation(self.args) + # for validation in self.validations + # ] - def validate(self): - return [ - validation(self.args) - for validation in self.validations - ] + # def __repr__(self): + # return self.repr - def __repr__(self): - return f"{self.definition.name}({', '.join(f'{k}={v}' for k, v in self.args.items())})" - - def __call__(self, *args, **kwargs): - return self.instantiate() + def run(self): + return self.func() diff --git a/pymerlin/_internal/_threaded_task.py b/pymerlin/_internal/_threaded_task.py index 021ff32..0562b12 100644 --- a/pymerlin/_internal/_threaded_task.py +++ b/pymerlin/_internal/_threaded_task.py @@ -5,8 +5,8 @@ from pymerlin._internal import _globals from pymerlin._internal._condition import Condition from pymerlin._internal._context import _set_context, _clear_context -from pymerlin._internal._decorators import TaskDefinition from pymerlin._internal._task_factory import TaskFactory +from pymerlin._internal._task_specification import TaskInstance from pymerlin._internal._task_status import Completed, Delayed, Awaiting, Calling # Host-to-task message types @@ -22,8 +22,11 @@ class ThreadedTaskHost: def __init__(self, gateway, model_type, task_provider): - if type(task_provider) is not TaskDefinition: + if type(task_provider) is not TaskInstance: raise ValueError(repr(task_provider)) + from pymerlin._internal._model_type import ModelType + if type(model_type) is not ModelType: + raise ValueError(repr(model_type)) self.gateway = gateway self.task_thread = _ThreadedTask(task_provider, model_type, gateway) self.task_thread_started = False @@ -81,8 +84,8 @@ def _spawn(self, task_provider): def _run(self, scheduler): try: - _set_context(scheduler, self._spawn, self) - result = self.task.run_task_definition() + _set_context(scheduler, self._spawn, self, self.model_type) + result = self.task.run() self.outbox.put(Finished(result)) except TaskAbort: self.outbox.put(Aborted()) @@ -97,7 +100,7 @@ def yield_with(self, status): self.outbox.put(Yield(status)) request = self.inbox.get() if type(request) == Resume: - _set_context(request.scheduler, self._spawn, self) + _set_context(request.scheduler, self._spawn, self, self.model_type) return elif type(request) == Abort: self.aborting = True diff --git a/pymerlin/model_actions.py b/pymerlin/model_actions.py index 49d33a0..7a5824e 100644 --- a/pymerlin/model_actions.py +++ b/pymerlin/model_actions.py @@ -6,8 +6,8 @@ import pymerlin._internal._task_status import pymerlin.duration import pymerlin._internal._globals -from pymerlin._internal._decorators import TaskDefinition from pymerlin._internal._spawn_helpers import activity_wrapper, get_topics +from pymerlin._internal._task_specification import TaskInstance def delay(duration): @@ -18,18 +18,12 @@ def delay(duration): return _yield_with(pymerlin._internal._task_status.Delayed(duration)) -def spawn_activity(model, child, args): +def spawn_activity(child): """ :param coro: :return: """ - topics = get_topics(model._model_type, child) - task_provider = TaskDefinition(lambda: activity_wrapper( - child, - args, - model, - *topics)) - pymerlin._internal._globals._current_context[1](task_provider) + pymerlin._internal._globals._current_context[1](child) def spawn_task(child, args): @@ -37,18 +31,18 @@ def spawn_task(child, args): :param coro: :return: """ - pymerlin._internal._globals._current_context[1](TaskDefinition(lambda: child.run_task_definition(**args))) + pymerlin._internal._globals._current_context[1](child.make_instance(**args)) -def call(model, child, args): - if type(child) is not TaskDefinition: +def call(child): + if type(child) is not TaskInstance: raise ValueError("Should be TaskDefinition, was: " + repr(child)) - task_provider = TaskDefinition(lambda: activity_wrapper( - child, - args, - model, - *get_topics(model._model_type, child))) - return _yield_with(pymerlin._internal._task_status.Calling(task_provider)) + # task_provider = TaskInstance(lambda: activity_wrapper( + # child, + # args, + # model, + # *get_topics(model._model_type, child))) + return _yield_with(pymerlin._internal._task_status.Calling(child)) def wait_until(condition): diff --git a/tests/test_simulation.py b/tests/test_simulation.py index b398a92..58f4e80 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -118,7 +118,7 @@ def test_spawn_activity(): @TestMissionModel.ActivityType def activity(mission: TestMissionModel): mission.counter.set(123) - spawn_activity(mission, other_activity, {}) + spawn_activity(other_activity(mission)) mission.counter.set(345) assert mission.counter.get() == 345 @@ -168,7 +168,7 @@ def test_call(): @TestMissionModel.ActivityType def activity(mission: TestMissionModel): mission.counter.set(123) - call(mission, other_activity, {}) + call(other_activity(mission)) assert mission.counter.get() == 345 delay("00:00:01") @@ -193,7 +193,7 @@ def test_call_task(): @TestMissionModel.ActivityType def activity(mission: TestMissionModel): mission.counter.set(123) - call(mission, subtask, {}) + call(subtask(mission)) assert mission.counter.get() == 345 delay("00:00:01")