From c769906e81246d6398161bdc931f9d90185a2ec1 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 21 Dec 2024 14:30:30 +0100 Subject: [PATCH] Controller snuck into broker --- environment.yml | 2 +- src/aiida/brokers/broker.py | 9 ++++++++- src/aiida/brokers/rabbitmq/broker.py | 9 ++++++++- src/aiida/engine/processes/process.py | 1 - src/aiida/engine/runners.py | 20 ++++++++++++-------- src/aiida/manage/manager.py | 16 +++++++--------- 6 files changed, 36 insertions(+), 21 deletions(-) diff --git a/environment.yml b/environment.yml index ad80dd341..d38ea2a43 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,7 @@ dependencies: - importlib-metadata~=6.0 - numpy~=1.21 - paramiko~=3.0 -- plumpy~=0.22.3 +- plumpy - pgsu~=0.3.0 - psutil~=5.6 - psycopg[binary]~=3.0 diff --git a/src/aiida/brokers/broker.py b/src/aiida/brokers/broker.py index 941c69833..1259a9c43 100644 --- a/src/aiida/brokers/broker.py +++ b/src/aiida/brokers/broker.py @@ -3,11 +3,13 @@ import abc import typing as t +from plumpy.controller import ProcessController if t.TYPE_CHECKING: - from aiida.manage.configuration.profile import Profile from plumpy.coordinator import Coordinator + from aiida.manage.configuration.profile import Profile + __all__ = ('Broker',) @@ -25,6 +27,11 @@ def __init__(self, profile: 'Profile') -> None: def get_coordinator(self) -> 'Coordinator': """Return an instance of coordinator.""" + @abc.abstractmethod + def get_controller(self) -> ProcessController: + """Return the process controller""" + ... + @abc.abstractmethod def iterate_tasks(self): """Return an iterator over the tasks in the launch queue.""" diff --git a/src/aiida/brokers/rabbitmq/broker.py b/src/aiida/brokers/rabbitmq/broker.py index 370afc6ac..0ed8bcd0d 100644 --- a/src/aiida/brokers/rabbitmq/broker.py +++ b/src/aiida/brokers/rabbitmq/broker.py @@ -5,7 +5,9 @@ import functools import typing as t -from plumpy.rmq import RmqCoordinator +from plumpy.rmq import RemoteProcessThreadController, RmqCoordinator +from plumpy import ProcessController +from plumpy.rmq.process_control import RemoteProcessController from aiida.brokers.broker import Broker from aiida.common.log import AIIDA_LOGGER @@ -15,6 +17,7 @@ if t.TYPE_CHECKING: from kiwipy.rmq import RmqThreadCommunicator + from aiida.manage.configuration.profile import Profile LOGGER = AIIDA_LOGGER.getChild('broker.rabbitmq') @@ -61,6 +64,10 @@ def get_coordinator(self): return coordinator + def get_controller(self) -> ProcessController: + coordinator = self.get_coordinator() + return RemoteProcessThreadController(coordinator) + def _create_communicator(self) -> 'RmqThreadCommunicator': """Return an instance of :class:`kiwipy.Communicator`.""" from kiwipy.rmq import RmqThreadCommunicator diff --git a/src/aiida/engine/processes/process.py b/src/aiida/engine/processes/process.py index a678b115c..1746cee93 100644 --- a/src/aiida/engine/processes/process.py +++ b/src/aiida/engine/processes/process.py @@ -43,7 +43,6 @@ # from kiwipy.communications import UnroutableError # from plumpy.processes import ConnectionClosed # type: ignore[attr-defined] from plumpy.process_states import Finished, ProcessState - from plumpy.processes import Process as PlumpyProcess from plumpy.utils import AttributesFrozendict diff --git a/src/aiida/engine/runners.py b/src/aiida/engine/runners.py index e1dd3c38f..92a62c071 100644 --- a/src/aiida/engine/runners.py +++ b/src/aiida/engine/runners.py @@ -27,6 +27,7 @@ from aiida.common import exceptions from aiida.orm import ProcessNode, load_node from aiida.plugins.utils import PluginVersionProvider +from aiida.brokers import Broker from . import transports, utils from .processes import Process, ProcessBuilder, ProcessState, futures @@ -64,7 +65,7 @@ def __init__( self, poll_interval: Union[int, float] = 0, loop: Optional[asyncio.AbstractEventLoop] = None, - coordinator: Optional[Coordinator] = None, + broker: Broker | None = None, broker_submit: bool = False, persister: Optional[Persister] = None, ): @@ -72,14 +73,14 @@ def __init__( :param poll_interval: interval in seconds between polling for status of active sub processes :param loop: an asyncio event loop, if none is suppled a new one will be created - :param coordinator: the coordinator to use + :param broker: the broker to use :param broker_submit: if True, processes will be submitted to the broker, otherwise they will be scheduled here :param persister: the persister to use to persist processes """ assert not ( broker_submit and persister is None - ), 'Must supply a persister if you want to submit using coordinator' + ), 'Must supply a persister if you want to submit using coordinator/broker' set_event_loop_policy() self._loop = loop or asyncio.get_event_loop() @@ -90,11 +91,14 @@ def __init__( self._persister = persister self._plugin_version_provider = PluginVersionProvider() - if coordinator is not None: - # FIXME: the wrap is not needed, when passed in, the coordinator should already wrapped - self._coordinator = wrap_communicator(coordinator.communicator, self._loop) - self._controller = RemoteProcessThreadController(coordinator) + # FIXME: broker and coordinator overlap the concept there for over-abstraction, remove the abstraction + if broker is not None: + _coordinator = broker.get_coordinator() + # FIXME: the wrap should not be needed + self._coordinator = wrap_communicator(_coordinator.communicator, self._loop) + self._controller = broker.get_controller() elif self._broker_submit: + # FIXME: if broker then broker_submit else False LOGGER.warning('Disabling broker submission, no coordinator provided') self._broker_submit = False @@ -350,7 +354,7 @@ def get_process_future(self, pk: int) -> futures.ProcessFuture: :return: A future representing the completion of the process node """ - return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._coordinator) + return futures.ProcessFuture(pk, self._loop, self._poll_interval, self.coordinator) def _poll_process(self, node, callback): """Check whether the process state of the node is terminated and call the callback or reschedule it. diff --git a/src/aiida/manage/manager.py b/src/aiida/manage/manager.py index 916589ccf..690051c73 100644 --- a/src/aiida/manage/manager.py +++ b/src/aiida/manage/manager.py @@ -10,14 +10,12 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING, Any, Optional, Union -import asyncio -import kiwipy from plumpy.coordinator import Coordinator if TYPE_CHECKING: - from kiwipy.rmq import RmqThreadCommunicator from plumpy.process_comms import RemoteProcessThreadController from aiida.brokers.broker import Broker @@ -169,7 +167,7 @@ def reset_profile_storage(self) -> None: self._profile_storage = None def reset_broker(self) -> None: - """Reset the communicator.""" + """Reset the broker.""" from concurrent import futures if self._broker is not None: @@ -401,7 +399,7 @@ def create_runner( self, poll_interval: Union[int, float] | None = None, loop: Optional[asyncio.AbstractEventLoop] = None, - coordinator: Optional[Coordinator] = None, + broker: Broker | None = None, broker_submit: bool = False, persister: Optional[AiiDAPersister] = None, ) -> 'Runner': @@ -422,19 +420,19 @@ def create_runner( _default_poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval') _default_broker_submit = False - _default_coordinator = self.get_coordinator() _default_persister = self.get_persister() + _default_broker = self.get_broker() runner = runners.Runner( poll_interval=poll_interval or _default_poll_interval, loop=loop or asyncio.get_event_loop(), - coordinator=coordinator or _default_coordinator, + broker=broker or _default_broker, broker_submit=broker_submit or _default_broker_submit, persister=persister or _default_persister, ) return runner - def create_daemon_runner(self, loop: Optional['asyncio.AbstractEventLoop'] = None) -> 'Runner': + def create_daemon_runner(self) -> 'Runner': """Create and return a new daemon runner. This is used by workers when the daemon is running and in testing. @@ -449,7 +447,7 @@ def create_daemon_runner(self, loop: Optional['asyncio.AbstractEventLoop'] = Non from aiida.engine import persistence from aiida.engine.processes.launcher import ProcessLauncher - runner = self.create_runner(broker_submit=True, loop=loop) + runner = self.create_runner(broker_submit=True, loop=None) runner_loop = runner.loop # Listen for incoming launch requests