Use a single stream redirector for both setup and predict (#2069)
Currently, for models that provide a non-async predict function, we have
two stream redirectors wrapping stdout/stderr.

We think that entering and exiting the stream redirector context multiple
times may have been causing unusual bugs, particularly with models that use
libraries like `tqdm` to render progress bars and perform other terminal
manipulation.

This PR changes this behaviour to use either the `StreamRedirector` or the
(newly renamed) `SimpleStreamRedirector`, depending on whether predict is
defined as an async function. We now use the same redirector for module
loading, setup and prediction.

The interrogation of the predict function now happens higher up the stack,
in http.py, when we read the input/output types via
`config.get_predictor_types`. That method now returns an additional value,
`is_async`, which is `True` when an async function is defined.
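
For reference, the check itself is small; here's a minimal sketch mirroring the helper added to config.py in this PR:

```python
import inspect
from typing import Any, Callable


def is_async(fn: Callable[[Any], Any]) -> bool:
    # Treat both `async def` coroutine functions and async generators
    # (async predictors that `yield`) as async.
    return inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)
```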

This is then passed down the stack into the worker, which uses it to select an
appropriate stream redirector instance. In future, when we come to support
async setup functions, we may need to revisit whether `is_async` is the
right name for the worker argument; `use_simple_stream` or similar might
be more appropriate.
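
Roughly, the new call flow looks like this (a condensed sketch of the http.py and worker.py diffs below, not standalone code):

```python
# In http.py: get_predictor_types now also reports whether predict is async,
# and the flag is threaded through make_worker.
InputType, OutputType, is_async = cog_config.get_predictor_types(mode=Mode.PREDICT)
worker = make_worker(
    predictor_ref=cog_config.get_predictor_ref(mode=mode),
    is_async=is_async,
)

# In the worker: the flag picks a single redirector up front, which then
# wraps module loading, setup and prediction alike.
if self._is_async:
    redirector = SimpleStreamRedirector(
        callback=self._stream_write_hook, tee=self._tee_output
    )
else:
    redirector = StreamRedirector(
        callback=self._stream_write_hook, tee=self._tee_output
    )
```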

This has been tested with both sync and async models, including flux-dev,
and seems to be working correctly.
aron authored Nov 29, 2024
1 parent 9cd4738 commit ed72ad5
Showing 8 changed files with 266 additions and 101 deletions.
28 changes: 22 additions & 6 deletions python/cog/config.py
@@ -1,7 +1,8 @@
import inspect
import os
import sys
import uuid
from typing import Optional, Tuple, Type
from typing import Any, Callable, Optional, Tuple, Type

import structlog
import yaml
@@ -16,7 +17,9 @@
from .predictor import (
get_input_type,
get_output_type,
get_predict,
get_predictor,
get_train,
get_training_input_type,
get_training_output_type,
load_full_predictor_from_file,
@@ -152,16 +155,29 @@ def get_predictor_ref(self, mode: Mode) -> str:

def get_predictor_types(
self, mode: Mode
) -> Tuple[Type[BaseInput], Type[BaseModel]]:
"""Find the input and output types of a predictor."""
) -> Tuple[Type[BaseInput], Type[BaseModel], bool]:
"""
Find the input & output types of a predictor/train function as well
as determining if the function is an async function.
"""
predictor_ref = self.get_predictor_ref(mode=mode)
predictor = self._load_predictor_for_types(
predictor_ref, _method_name_from_mode(mode=mode), mode
)

def is_async(fn: Callable[[Any], Any]) -> bool:
return inspect.iscoroutinefunction(fn) or inspect.isasyncgenfunction(fn)

if mode == Mode.PREDICT:
return get_input_type(predictor), get_output_type(predictor)
return (
get_input_type(predictor),
get_output_type(predictor),
is_async(get_predict(predictor)),
)
elif mode == Mode.TRAIN:
return get_training_input_type(predictor), get_training_output_type(
predictor
return (
get_training_input_type(predictor),
get_training_output_type(predictor),
is_async(get_train(predictor)),
)
raise ValueError(f"Mode {mode} not found for generating input/output types.")
12 changes: 6 additions & 6 deletions python/cog/server/helpers.py
@@ -127,19 +127,19 @@ def original(self) -> TextIO:

if sys.version_info < (3, 9):

class _AsyncStreamRedirectorBase(contextlib.AbstractContextManager):
class _SimpleStreamRedirectorBase(contextlib.AbstractContextManager):
pass
else:

class _AsyncStreamRedirectorBase(
contextlib.AbstractContextManager["AsyncStreamRedirector"]
class _SimpleStreamRedirectorBase(
contextlib.AbstractContextManager["SimpleStreamRedirector"]
):
pass


class AsyncStreamRedirector(_AsyncStreamRedirectorBase):
class SimpleStreamRedirector(_SimpleStreamRedirectorBase):
"""
AsyncStreamRedirector is a context manager that redirects I/O streams to a
SimpleStreamRedirector is a context manager that redirects I/O streams to a
callback function. If `tee` is True, it also writes output to the original
streams.
@@ -179,7 +179,7 @@ def __exit__(
self._stderr_ctx.__exit__(exc_type, exc_value, traceback)

def drain(self, timeout: float = 0.0) -> None:
# Draining isn't complicated for AsyncStreamRedirector, since we're not
# Draining isn't complicated for SimpleStreamRedirector, since we're not
# moving data between threads. We just need to flush the streams.
sys.stdout.flush()
sys.stderr.flush()
10 changes: 7 additions & 3 deletions python/cog/server/http.py
@@ -156,13 +156,17 @@ async def start_shutdown() -> Any:
return JSONResponse({}, status_code=200)

try:
InputType, OutputType = cog_config.get_predictor_types(mode=Mode.PREDICT)
InputType, OutputType, is_async = cog_config.get_predictor_types(
mode=Mode.PREDICT
)
except Exception: # pylint: disable=broad-exception-caught
msg = "Error while loading predictor:\n\n" + traceback.format_exc()
add_setup_failed_routes(app, started_at, msg)
return app

worker = make_worker(predictor_ref=cog_config.get_predictor_ref(mode=mode))
worker = make_worker(
predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async
)
runner = PredictionRunner(worker=worker)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
@@ -197,7 +201,7 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa

if cog_config.predictor_train_ref:
try:
TrainingInputType, TrainingOutputType = cog_config.get_predictor_types(
TrainingInputType, TrainingOutputType, _ = cog_config.get_predictor_types(
Mode.TRAIN
)

157 changes: 96 additions & 61 deletions python/cog/server/worker.py
@@ -50,7 +50,7 @@
FatalWorkerException,
InvalidStateException,
)
from .helpers import AsyncStreamRedirector, StreamRedirector
from .helpers import SimpleStreamRedirector, StreamRedirector
from .scope import Scope, scope

if PYDANTIC_V2:
@@ -348,6 +348,8 @@ class _ChildWorker(_spawn.Process): # type: ignore
def __init__(
self,
predictor_ref: str,
*,
is_async: bool,
events: Connection,
tee_output: bool = True,
) -> None:
@@ -361,6 +363,7 @@ def __init__(

# for synchronous predictors only! async predictors use _tag_var instead
self._sync_tag: Optional[str] = None
self._is_async = is_async

super().__init__()

Expand All @@ -373,33 +376,41 @@ def run(self) -> None:
# Initially, we ignore SIGUSR1.
signal.signal(signal.SIGUSR1, signal.SIG_IGN)

async_redirector = AsyncStreamRedirector(
callback=self._stream_write_hook,
tee=self._tee_output,
)

with async_redirector:
self._setup(async_redirector)

# If setup didn't set the predictor, we're done here.
if not self._predictor:
return

predict = get_predict(self._predictor)
if inspect.iscoroutinefunction(predict) or inspect.isasyncgenfunction(predict):
asyncio.run(self._aloop(predict, async_redirector))
if self._is_async:
redirector = SimpleStreamRedirector(
callback=self._stream_write_hook,
tee=self._tee_output,
)
else:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)

self._loop(
predict,
StreamRedirector(
callback=self._stream_write_hook,
tee=self._tee_output,
),
redirector = StreamRedirector(
callback=self._stream_write_hook,
tee=self._tee_output,
)

with scope(Scope(record_metric=self.record_metric)), redirector:
self._predictor = self._load_predictor()

# If _load_predictor hasn't returned a predictor instance then
# it has sent an error Done event and we're done here.
if not self._predictor:
return

predict = get_predict(self._predictor)
if self._is_async:
assert isinstance(redirector, SimpleStreamRedirector)
self._setup(redirector)
asyncio.run(self._aloop(predict, redirector))
else:
# We use SIGUSR1 to signal an interrupt for cancelation.
signal.signal(signal.SIGUSR1, self._signal_handler)

assert isinstance(redirector, StreamRedirector)
self._setup(redirector)
self._loop(
predict,
redirector,
)

def send_cancel(self) -> None:
if self.is_alive() and self.pid:
os.kill(self.pid, signal.SIGUSR1)
@@ -417,11 +428,34 @@ def _current_tag(self) -> Optional[str]:
return tag
return self._sync_tag

def _setup(self, redirector: AsyncStreamRedirector) -> None:
def _load_predictor(self) -> Optional[BasePredictor]:
done = Done()
wait_for_env()
try:
self._predictor = load_predictor_from_ref(self._predictor_ref)
return load_predictor_from_ref(self._predictor_ref)
except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
self._events.send(Envelope(event=done))
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
self._events.send(Envelope(event=done))
raise

return None

def _setup(
self, redirector: Union[StreamRedirector, SimpleStreamRedirector]
) -> None:
done = Done()
try:
assert self._predictor

# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)
@@ -456,47 +490,42 @@ def _loop(
predict: Callable[..., Any],
redirector: StreamRedirector,
) -> None:
with scope(self._loop_scope()), redirector:
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
self._predict(e.tag, e.event.payload, predict, redirector)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
self._predict(e.tag, e.event.payload, predict, redirector)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)

async def _aloop(
self,
predict: Callable[..., Any],
redirector: AsyncStreamRedirector,
redirector: SimpleStreamRedirector,
) -> None:
# Unwrap and replace the events connection with an async one.
assert isinstance(self._events, LockedConnection)
self._events = AsyncConnection(self._events.connection)

task = None

with scope(self._loop_scope()), redirector:
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = asyncio.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
if task:
await task

def _loop_scope(self) -> Scope:
return Scope(record_metric=self.record_metric)
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = asyncio.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
print(f"Got unexpected event: {e.event}", file=sys.stderr)
if task:
await task

def _predict(
self,
@@ -554,7 +583,7 @@ async def _apredict(
tag: Optional[str],
payload: Dict[str, Any],
predict: Callable[..., Any],
redirector: AsyncStreamRedirector,
redirector: SimpleStreamRedirector,
) -> None:
_tag_var.set(tag)

@@ -606,7 +635,7 @@ async def _apredict(
@contextlib.contextmanager
def _handle_predict_error(
self,
redirector: Union[AsyncStreamRedirector, StreamRedirector],
redirector: Union[SimpleStreamRedirector, StreamRedirector],
tag: Optional[str],
) -> Iterator[None]:
done = Done()
@@ -680,10 +709,16 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:


def make_worker(
predictor_ref: str, tee_output: bool = True, max_concurrency: int = 1
predictor_ref: str,
*,
is_async: bool,
tee_output: bool = True,
max_concurrency: int = 1,
) -> Worker:
parent_conn, child_conn = _spawn.Pipe()
child = _ChildWorker(predictor_ref, events=child_conn, tee_output=tee_output)
child = _ChildWorker(
predictor_ref, events=child_conn, tee_output=tee_output, is_async=is_async
)
parent = Worker(child=child, events=parent_conn, max_concurrency=max_concurrency)
return parent
