Skip to content

Commit

Permalink
Adapt to the new message type and resolve import issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
unkcpz committed Dec 20, 2024
1 parent 92f1683 commit 6085bca
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -371,8 +371,8 @@ def process_pause(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down
19 changes: 11 additions & 8 deletions src/aiida/engine/processes/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import collections
import concurrent
import functools
import typing as t

import kiwipy
Expand All @@ -18,7 +19,7 @@
from aiida.orm import ProcessNode, QueryBuilder
from aiida.tools.query.calculation import CalculationQueryBuilder

LOGGER = AIIDA_LOGGER.getChild('process_control')
LOGGER = AIIDA_LOGGER.getChild('engine.processes')


class ProcessTimeoutException(AiidaException):
Expand Down Expand Up @@ -135,7 +136,7 @@ def play_processes(
def pause_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
msg_text: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -164,13 +165,14 @@ def pause_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.pause_process, 'pause', 'pausing', timeout, wait, msg=message)
action = functools.partial(controller.pause_process, msg_text=msg_text)
_perform_actions(processes, action, 'pause', 'pausing', timeout, wait)


def kill_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -199,7 +201,8 @@ def kill_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message)
action = functools.partial(controller.kill_process, msg_text=msg_text)
_perform_actions(processes, action, 'kill', 'killing', timeout, wait)


def _perform_actions(
Expand Down Expand Up @@ -281,9 +284,9 @@ def handle_result(result):
unwrapped = unwrap_kiwi_future(future)
result = unwrapped.result()
except communications.TimeoutError:
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out')
LOGGER.error(f'call to {infinitive} Process<{process.pk}> timed out', exc_info=True)
except Exception as exception:
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}')
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}', exc_info=True)
else:
if isinstance(result, kiwipy.Future):
LOGGER.report(f'scheduled {infinitive} Process<{process.pk}>')
Expand All @@ -302,7 +305,7 @@ def handle_result(result):
try:
result = future.result()
except Exception as exception:
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}')
LOGGER.error(f'failed to {infinitive} Process<{process.pk}>: {exception}', exc_info=True)
else:
handle_result(result)

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode
def kill_process(_num, _frame):
"""Send the kill signal to the process in the current scope."""
LOGGER.critical('runner received interrupt, killing process %s', process.pid)
result = process.kill(msg='Process was killed because the runner received an interrupt')
result = process.kill(msg_text='Process was killed because the runner received an interrupt')
return result

# Store the current handler on the signal such that it can be restored after process has terminated
Expand Down
15 changes: 9 additions & 6 deletions src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
import plumpy.futures
import plumpy.persistence
import plumpy.processes

# from kiwipy.communications import UnroutableError
from plumpy.process_states import Finished, ProcessState

# from plumpy.processes import ConnectionClosed # type: ignore[attr-defined]
from plumpy.processes import Process as PlumpyProcess
from plumpy.utils import AttributesFrozendict
Expand Down Expand Up @@ -318,7 +320,7 @@ def load_instance_state(
else:
self._runner = manager.get_manager().get_runner()

load_context = load_context.copyextend(loop=self._runner.loop, communicator=self._runner.communicator)
load_context = load_context.copyextend(loop=self._runner.loop, coordinator=self._runner.communicator)
super().load_instance_state(saved_state, load_context)

if self.SaveKeys.CALC_ID.value in saved_state:
Expand All @@ -329,7 +331,7 @@ def load_instance_state(

self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]:
def kill(self, msg_text: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]:
"""Kill the process and all the children calculations it called
:param msg: message
Expand All @@ -338,7 +340,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur

had_been_terminated = self.has_terminated()

result = super().kill(msg)
result = super().kill(msg_text)

# Only kill children if we could be killed ourselves
if result is not False and not had_been_terminated:
Expand All @@ -348,10 +350,11 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur
self.logger.info('no controller available to kill child<%s>', child.pk)
continue
try:
result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>')
result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
result = asyncio.wrap_future(result) # type: ignore[arg-type]
if asyncio.isfuture(result):
killing.append(result)
# FIXME: use generic exception to catch the coordinator side exception
# except ConnectionClosed:
# self.logger.info('no connection available to kill child<%s>', child.pk)
# except UnroutableError:
Expand All @@ -365,10 +368,10 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur

if killing:
# We are waiting for things to be killed, so return the 'gathered' future
kill_future = plumpy.futures.gather(*killing)
kill_future = asyncio.gather(*killing)
result = self.loop.create_future()

def done(done_future: plumpy.futures.Future):
def done(done_future: asyncio.Future):
is_all_killed = all(done_future.result())
result.set_result(is_all_killed)

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def kill_process(_num, _frame):
LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid)
return
LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid)
process_inited.kill(msg='Process was killed because the runner received an interrupt')
process_inited.kill(msg_text='Process was killed because the runner received an interrupt')

original_handler_int = signal.getsignal(signal.SIGINT)
original_handler_term = signal.getsignal(signal.SIGTERM)
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,10 @@ def get_process_controller(self) -> 'RemoteProcessThreadController':
:return: the process controller instance
"""
from plumpy.process_comms import RemoteProcessThreadController
from plumpy.rmq import RemoteProcessThreadController

if self._process_controller is None:
# FIXME: use coordinator wrapper
self._process_controller = RemoteProcessThreadController(self.get_communicator())

return self._process_controller
Expand Down
1 change: 1 addition & 0 deletions tests/engine/processes/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_kill_processes(submit_and_await):
node = submit_and_await(WaitProcess, ProcessState.WAITING)

control.kill_processes([node], wait=True)
# __import__('ipdb').set_trace()
assert node.is_terminated
assert node.is_killed
assert node.process_status == 'Killed through `aiida.engine.processes.control.kill_processes`'
Expand Down
8 changes: 4 additions & 4 deletions tests/engine/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def do_pause():
assert calc_node.paused

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message)
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -112,7 +112,7 @@ async def do_pause_play():
await asyncio.sleep(0.1)

pause_message = 'Take a seat'
pause_future = controller.pause_process(calc_node.pk, msg=pause_message)
pause_future = controller.pause_process(calc_node.pk, msg_text=pause_message)
future = await with_timeout(asyncio.wrap_future(pause_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert calc_node.paused
Expand All @@ -127,7 +127,7 @@ async def do_pause_play():
assert calc_node.process_status is None

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message)
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -145,7 +145,7 @@ async def do_kill():
await asyncio.sleep(0.1)

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message)
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand Down
2 changes: 1 addition & 1 deletion tests/engine/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_call_on_process_finish(runner):
"""Test call on calculation finish."""
loop = runner.loop
proc = Proc(runner=runner, inputs={'a': Str('input')})
future = plumpy.Future()
future = asyncio.Future()
event = threading.Event()

def calc_done():
Expand Down
4 changes: 2 additions & 2 deletions tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def run_until_paused(proc):
"""Set up a future that will be resolved when process is paused"""
listener = plumpy.ProcessListener()
paused = plumpy.Future()
paused = asyncio.Future()

if proc.paused:
paused.set_result(True)
Expand All @@ -49,7 +49,7 @@ def run_until_waiting(proc):
from aiida.engine import ProcessState

listener = plumpy.ProcessListener()
in_waiting = plumpy.Future()
in_waiting = asyncio.Future()

if proc.state == ProcessState.WAITING:
in_waiting.set_result(True)
Expand Down

0 comments on commit 6085bca

Please sign in to comment.