Skip to content

Commit

Permalink
Unit tests, add support for generator functions with returns
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Mar 7, 2024
1 parent c77992c commit b10b275
Show file tree
Hide file tree
Showing 7 changed files with 544 additions and 45 deletions.
3 changes: 3 additions & 0 deletions colmena/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from colmena.models.results import Result, ResourceRequirements, FailureInformation, SerializationMethod

# NOTE: ExecutableTask moved to ``colmena.models.tasks`` and is no longer re-exported
# from this package, so it must not appear in ``__all__`` — a stale entry makes
# ``from colmena.models import *`` raise AttributeError. Import it from
# ``colmena.models.tasks`` instead.
__all__ = ['Result', 'ResourceRequirements', 'FailureInformation', 'SerializationMethod']
488 changes: 488 additions & 0 deletions colmena/models/results.py

Large diffs are not rendered by default.

58 changes: 46 additions & 12 deletions colmena/models/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from tempfile import TemporaryDirectory
from time import perf_counter
from inspect import signature, isgeneratorfunction
from typing import Any, Dict, List, Tuple, Optional, Callable, Iterable
from typing import Any, Dict, List, Tuple, Optional, Callable, Generator

from .results import ResourceRequirements, Result, FailureInformation
from colmena.models.results import ResourceRequirements, Result, FailureInformation
from colmena.proxy import resolve_proxies_async, store_proxy_stats
from colmena.queue import ColmenaQueues

Expand All @@ -33,11 +33,12 @@ def function(self, *args, **kwargs) -> Any:
"""Function provided by the Colmena user"""
raise NotImplementedError()

def __call__(self, result: Result, queues: Optional[ColmenaQueues]) -> Result:
def __call__(self, result: Result, queues: Optional[ColmenaQueues] = None) -> Result:
"""Invoke a Colmena task request
Args:
result: Request, which includes the arguments and will hold the result
queues: Queues used to send intermediate results back [Not Yet Used]
Returns:
The input result object, populated with the results
"""
Expand Down Expand Up @@ -100,7 +101,12 @@ def __call__(self, result: Result, queues: Optional[ColmenaQueues]) -> Result:


class PythonTask(ColmenaTask):
"""A Python function to be executed on a single worker of a larger workflow"""
"""A Python function to be executed on a single worker
Args:
function: Generator function to be executed
name: Name of the function. Defaults to `function.__name__`
"""

function: Callable

Expand All @@ -112,21 +118,39 @@ def __init__(self, function: Callable, name: Optional[str] = None) -> None:


class PythonGeneratorTask(ColmenaTask):
    """Python task which runs on a single worker and generates results iteratively

    Args:
        function: Generator function to be executed
        name: Name of the function. Defaults to ``function.__name__``
        store_return_value: Whether to capture the `return value
            <https://docs.python.org/3/reference/simple_stmts.html#the-return-statement>`_
            of the generator and store it in the Result object.
    """

    def __init__(self,
                 function: Callable[..., Generator],
                 name: Optional[str] = None,
                 store_return_value: bool = False) -> None:
        # Fail fast on ordinary callables: only generator functions yield intermediate results
        if not isgeneratorfunction(function):
            raise ValueError('Function is not a generator function. Use `PythonTask` instead.')
        self._function = function
        self.name = name or function.__name__
        self.store_return_value = store_return_value

    def function(self, *args, **kwargs) -> Any:
        """Run the Colmena task and collect intermediate results to provide as a list

        Returns:
            Either the list of yielded values, or — when ``store_return_value`` is set —
            a tuple of (yielded values, generator return value)
        """

        # TODO (wardlt): Have the function push intermediate results back to a function queue
        gen = self._function(*args, **kwargs)
        iter_results = []
        while True:
            try:
                iter_results.append(next(gen))
            except StopIteration as e:
                # PEP 380: a generator's ``return x`` surfaces as StopIteration.value
                if self.store_return_value:
                    return iter_results, e.value
                else:
                    return iter_results


class ExecutableTask(ColmenaTask):
Expand Down Expand Up @@ -160,6 +184,12 @@ class ExecutableTask(ColmenaTask):
The attributes of this class (e.g., ``node_count``, ``total_ranks``) will be used as arguments to `format`.
For example, a template of ``aprun -N {total_ranks} -n {cpu_process}`` will produce ``aprun -N 6 -n 3`` if you
specify ``node_count=2`` and ``cpu_processes=3``.
Args:
executable: List of executable arguments
name: Name used for the task. Defaults to ``executable[0]``
mpi: Whether to use MPI to launch the executable
mpi_command_string: Template for MPI launcher. See :attr:`mpi_command_string`.
"""

executable: List[str]
Expand All @@ -173,9 +203,13 @@ class ExecutableTask(ColmenaTask):
Should include placeholders named after the fields in ResourceRequirements marked using {}'s.
Example: `mpirun -np {total_ranks}`"""

@property
def __name__(self):
return self.__class__.__name__.lower()
def __init__(self, executable: List[str], name: Optional[str] = None,
             mpi: bool = False, mpi_command_string: Optional[str] = None) -> None:
    """Set up an executable-backed task.

    Args:
        executable: List of executable arguments
        name: Name used for the task. Defaults to ``executable[0]``
        mpi: Whether to use MPI to launch the executable
        mpi_command_string: Template for the MPI launcher command
    """
    super().__init__()
    # Store the launch configuration first, then derive the task name
    self.executable = executable
    self.mpi = mpi
    self.mpi_command_string = mpi_command_string
    # Fall back to the program name when no (truthy) task name was supplied
    self.name = name if name else executable[0]

def render_mpi_launch(self, resources: ResourceRequirements) -> str:
"""Create an MPI launch command given the configuration
Expand Down
4 changes: 2 additions & 2 deletions colmena/queue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Implementations of the task and result queues from Colmena"""

from .base import ColmenaQueues # noqa: 401
from .python import PipeQueues # noqa: 401
from colmena.queue.base import ColmenaQueues # noqa: 401
from colmena.queue.python import PipeQueues # noqa: 401
3 changes: 2 additions & 1 deletion colmena/task_server/parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from parsl.app.bash import BashApp
from parsl.config import Config
from parsl.app.python import PythonApp
from colmena.models.tasks import ExecutableTask

from colmena.queue.base import ColmenaQueues
from colmena.models import Result, ExecutableTask, FailureInformation, ResourceRequirements
from colmena.models import Result, FailureInformation, ResourceRequirements
from colmena.proxy import resolve_proxies_async
from colmena.task_server.base import run_and_record_timing, FutureBasedTaskServer

Expand Down
3 changes: 2 additions & 1 deletion colmena/task_server/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from proxystore.store import unregister_store
from pytest import fixture

from colmena.models import Result, ExecutableTask, SerializationMethod
from colmena.models import Result, SerializationMethod
from colmena.models.tasks import ExecutableTask
from colmena.task_server.base import run_and_record_timing


Expand Down
30 changes: 1 addition & 29 deletions colmena/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
"""Tests for the data models"""
import sys
from typing import Any, Tuple, Dict, List, Optional
from pathlib import Path

from colmena.models import ResourceRequirements, ExecutableTask, Result


class EchoTask(ExecutableTask):
def __init__(self):
super().__init__(executable=['echo'])

def preprocess(self, run_dir: Path, args: Tuple[Any], kwargs: Dict[str, Any]) -> Tuple[List[str], Optional[str]]:
return list(map(str, args)), None

def postprocess(self, run_dir: Path) -> Any:
return (run_dir / 'colmena.stdout').read_text()
from colmena.models import ResourceRequirements, Result


def test_resources():
Expand All @@ -41,18 +28,3 @@ def test_message_sizes():
result.serialize()
assert result.message_sizes['inputs'] >= 2 * sys.getsizeof('0' * 8)
assert result.message_sizes['inputs'] >= sys.getsizeof(1)


def test_executable_task():
# Run a basic task
task = EchoTask()
assert task.executable == ['echo']
assert task(1) == '1\n'

# Run an "MPI task"
task.mpi = True
task.mpi_command_string = 'aprun -N {total_ranks} -n {cpu_processes} --cc depth'
assert task.render_mpi_launch(ResourceRequirements(node_count=2, cpu_processes=4)) == 'aprun -N 8 -n 4 --cc depth'

task.mpi_command_string = 'echo -N {total_ranks} -n {cpu_processes} --cc depth'
assert task(1, _resources=ResourceRequirements(node_count=2, cpu_processes=3)) == '-N 6 -n 3 --cc depth echo 1\n'

0 comments on commit b10b275

Please sign in to comment.