diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8884a316..f3152cfd9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,8 @@
 
 🔥 *Features*
 
+* Added `max_retries` to the `sdk.task` decorator. It allows users to restart Ray workers after system failures (like OOM kills or SIGTERMs). Restarts do not happen for Python exceptions.
+
 🧟 *Deprecations*
 
 👩‍🔬 *Experimental*
diff --git a/src/orquestra/sdk/_base/_dsl.py b/src/orquestra/sdk/_base/_dsl.py
index dc1ca3a9d..b4124d35b 100644
--- a/src/orquestra/sdk/_base/_dsl.py
+++ b/src/orquestra/sdk/_base/_dsl.py
@@ -515,6 +515,7 @@ def __init__(
         custom_image: Optional[str] = None,
         custom_name: Optional[str] = None,
         fn_ref: Optional[FunctionRef] = None,
+        max_retries: Optional[int] = None,
     ):
         if isinstance(fn, BuiltinFunctionType):
             raise NotImplementedError("Built-in functions are not supported as Tasks")
@@ -531,6 +532,7 @@ def __init__(
         self._use_default_dependency_imports = dependency_imports is None
         self._source_import = source_import
         self._use_default_source_import = source_import is None
+        self._max_retries = max_retries
 
         # task itself is not part of any workflow yet. Don't pass wf defaults
         self._resolve_task_source_data()
@@ -1093,6 +1095,7 @@ def task(
     n_outputs: Optional[int] = None,
     custom_image: Optional[str] = None,
     custom_name: Optional[str] = None,
+    max_retries: Optional[int] = None,
 ) -> Callable[[Callable[_P, _R]], TaskDef[_P, _R]]:
     ...
 
@@ -1107,6 +1110,7 @@ def task(
     n_outputs: Optional[int] = None,
     custom_image: Optional[str] = None,
     custom_name: Optional[str] = None,
+    max_retries: Optional[int] = None,
 ) -> TaskDef[_P, _R]:
     ...
 
@@ -1120,6 +1124,7 @@ def task(
     n_outputs: Optional[int] = None,
     custom_image: Optional[str] = None,
     custom_name: Optional[str] = None,
+    max_retries: Optional[int] = None,
 ) -> Union[TaskDef[_P, _R], Callable[[Callable[_P, _R]], TaskDef[_P, _R]]]:
     """Wraps a function into an Orquestra Task.
 
@@ -1151,6 +1156,12 @@ def task(
             result of other task) - it will be placeholded. Every character that
             is non-alphanumeric will be changed to dash ("-"). Also only first
             128 characters of the name will be used
+        max_retries: Maximum number of times a task will be retried after its
+            worker fails, e.g. when the worker process is killed by external
+            events or crashes because of memory leaks from previously executed
+            tasks. Python exceptions raised by the task are not retried.
+            WARNING: retries can interfere with MLflow logging. Retried runs
+            share the same task invocation ID, so they share the same MLflow identifier.
 
     Raises:
         ValueError: when a task has fewer than 1 outputs.
@@ -1188,6 +1199,7 @@ def _inner(fn: Callable[_P, _R]):
             output_metadata=output_metadata,
             custom_image=custom_image,
             custom_name=custom_name,
+            max_retries=max_retries,
         )
 
         return task_def
diff --git a/src/orquestra/sdk/_base/_traversal.py b/src/orquestra/sdk/_base/_traversal.py
index f799e8b0d..bc0f78994 100644
--- a/src/orquestra/sdk/_base/_traversal.py
+++ b/src/orquestra/sdk/_base/_traversal.py
@@ -590,6 +590,7 @@ def _make_task_model(
         resources=resources,
         parameters=parameters,
         custom_image=task._custom_image,
+        max_retries=task._max_retries,
     )
 
 
diff --git a/src/orquestra/sdk/_ray/_build_workflow.py b/src/orquestra/sdk/_ray/_build_workflow.py
index 7992110a7..5a5f99b2f 100644
--- a/src/orquestra/sdk/_ray/_build_workflow.py
+++ b/src/orquestra/sdk/_ray/_build_workflow.py
@@ -506,13 +506,16 @@ def make_ray_dag(
             # If there are any python packages to install for step - set runtime env
             "runtime_env": (_client.RuntimeEnv(pip=pip) if len(pip) > 0 else None),
             "catch_exceptions": False,
-            # We only want to execute workflow tasks once. This is so there is only one
-            # task run ID per task, for scenarios where this is used (like in MLFlow).
+            # We only want to execute workflow tasks once by default.
+            # This is so there is only one task run ID per task, for scenarios
+            # where this is used (like in MLflow). We allow overriding this on the
+            # task level for particular edge cases, such as memory leaks inside
+            # 3rd-party libraries, so that a worker killed by the OOM killer can
+            # be restarted.
             # By default, Ray will only retry tasks that fail due to a "system error".
             # For example, if the worker process crashes or exits early.
             # Normal Python exceptions are NOT retried.
-            # So, we turn max_retries down to 0.
-            "max_retries": 0,
+            "max_retries": user_task.max_retries if user_task.max_retries else 0,
         }
 
         # Non-custom task resources
diff --git a/src/orquestra/sdk/schema/ir.py b/src/orquestra/sdk/schema/ir.py
index 86b970d97..7bdc01ab1 100644
--- a/src/orquestra/sdk/schema/ir.py
+++ b/src/orquestra/sdk/schema/ir.py
@@ -238,6 +238,7 @@ class TaskDef(BaseModel):
 
     resources: t.Optional[Resources] = None
 
+    max_retries: t.Optional[int] = None
     # Hints the runtime to run this task in a docker container with this image. Has no
     # effect if the runtime doesn't support it.
     custom_image: t.Optional[str] = None
diff --git a/tests/runtime/ray/test_integration.py b/tests/runtime/ray/test_integration.py
index c88cc7ae2..5777843c7 100644
--- a/tests/runtime/ray/test_integration.py
+++ b/tests/runtime/ray/test_integration.py
@@ -1340,3 +1340,75 @@ def wf():
     # Precondition
     wf_run = runtime.get_workflow_run_status(wf_run_id)
     assert wf_run.status.state == State.SUCCEEDED
+
+
+@pytest.mark.slow
+class TestRetries:
+    """
+    Test that retrying Ray workers works properly.
+    """
+
+    @pytest.mark.parametrize(
+        "max_retries,should_fail",
+        [
+            (1, False),  # we should not fail with max_retries enabled
+            (50, False),  # we should not fail with max_retries enabled
+            (0, True),  # 0 means do not retry
+            (None, True),  # We do not enable max_retries by default
+        ],
+    )
+    def test_max_retries(self, runtime: _dag.RayRuntime, max_retries, should_fail):
+        @sdk.task(max_retries=max_retries)
+        def generic_task(*args):
+            if hasattr(sdk, "l"):
+                sdk.l.extend([0])  # type: ignore # noqa
+            else:
+                setattr(sdk, "l", [0])
+            if len(sdk.l) == 2:  # type: ignore # noqa
+                import os
+                import signal
+
+                os.kill(os.getpid(), signal.SIGTERM)
+
+            return None
+
+        @sdk.workflow
+        def wf():
+            task_res = None
+            for _ in range(5):
+                task_res = generic_task(task_res)
+            return task_res
+
+        wf_model = wf().model
+
+        # When
+        # The function-under-test is called inside the workflow.
+        wf_run_id = runtime.create_workflow_run(wf_model, project=None, dry_run=False)
+
+        # We can't base our assertions on the SDK workflow status because of:
+        # https://zapatacomputing.atlassian.net/browse/ORQSDK-1024
+        # Instead, we peek at Ray's own workflow status to see that the workflow
+        # actually failed, even though we report it as RUNNING.
+        import ray.workflow
+        from ray.workflow.common import WorkflowStatus
+
+        no_of_retries = 0
+
+        while True:
+            ray_status = ray.workflow.get_status(wf_run_id)
+            if no_of_retries >= 30:
+                break
+            if ray_status == WorkflowStatus.RUNNING:
+                time.sleep(1)
+                no_of_retries += 1
+                continue
+            if (
+                ray_status == WorkflowStatus.FAILED
+                or ray_status == WorkflowStatus.SUCCESSFUL
+            ):
+                break
+
+        if should_fail:
+            assert ray_status == WorkflowStatus.FAILED
+        else:
+            assert ray_status == WorkflowStatus.SUCCESSFUL
diff --git a/tests/sdk/test_dsl.py b/tests/sdk/test_dsl.py
index 52be76f37..07319e494 100644
--- a/tests/sdk/test_dsl.py
+++ b/tests/sdk/test_dsl.py
@@ -545,6 +545,14 @@ def _local_task_1(x):
     assert len(warns.list) == 1
 
 
+def test_max_retries():
+    @_dsl.task(max_retries=5)
+    def task():
+        ...
+
+    assert task._max_retries == 5
+
+
 def test_default_import_type(monkeypatch):
     @_dsl.task
     def task():
diff --git a/tests/sdk/test_traversal.py b/tests/sdk/test_traversal.py
index 28418c773..d4ca88b8f 100644
--- a/tests/sdk/test_traversal.py
+++ b/tests/sdk/test_traversal.py
@@ -974,6 +974,26 @@ def workflow():
     assert list(workflow.model.task_invocations.keys())[0] == expected
 
 
+@pytest.mark.parametrize(
+    "argument, expected",
+    [
+        (None, None),
+        (1, 1),
+        (999, 999),
+    ],
+)
+def test_max_retries(argument, expected):
+    @_dsl.task(max_retries=argument)
+    def task():
+        ...
+
+    @_workflow.workflow()
+    def workflow():
+        return task()
+
+    assert list(workflow.model.tasks.values())[0].max_retries == expected
+
+
 class TestNumberOfFetchesOnInferRepos:
     @pytest.fixture()
     def setup_fetch(self, monkeypatch):
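A minimal usage sketch of the new parameter, assuming the conventional `import orquestra.sdk as sdk` entry point; the task and workflow names below are illustrative only and not part of the diff:

import orquestra.sdk as sdk


@sdk.task(max_retries=3)  # illustrative value; retries cover system failures only
def preprocess(data):
    # If the worker is OOM-killed or receives SIGTERM, Ray restarts it up to
    # 3 times. A Python exception raised here is NOT retried.
    return data


@sdk.workflow
def my_workflow():
    return preprocess([1, 2, 3])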