End-to-end support for concurrent async models (#2066)
This builds on the work in #2057 and wires it up end-to-end.

We can now support async models with a configured max concurrency, and submit
multiple predictions to them concurrently.

We only support Python 3.11 and later for async models; this is so that we can
use asyncio.TaskGroup (new in 3.11) to keep track of multiple in-flight
predictions and ensure they all complete when shutting down.
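
As a rough illustration (a minimal sketch, not the actual worker code), this is
the TaskGroup property we rely on: the `async with` block only exits once every
task in the group has completed.

```python
# Minimal sketch: asyncio.TaskGroup (Python 3.11+) tracks in-flight work and
# guarantees it has all completed before the block exits.
import asyncio

async def predict(i: int) -> str:
    await asyncio.sleep(0.1)  # stand-in for real prediction work
    return f"prediction {i} done"

async def main() -> None:
    async with asyncio.TaskGroup() as tg:
        for i in range(4):  # several predictions in flight at once
            tg.create_task(predict(i))
    # Reaching this line means every prediction task has finished,
    # which is the shutdown guarantee we want.
    print("all predictions complete")

asyncio.run(main())
```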

The cog HTTP server was already async, but at one point it called wait() on a
concurrent.futures.Future, which blocked the event loop and therefore prevented
concurrent prediction requests (when not using the Prefer: respond-async header,
which is how the tests run). I have updated this code to await
asyncio.wrap_future(fut) instead, which does not block the event loop. As part
of this I have updated the training endpoints to also be asynchronous.
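
To make the fix concrete, here is a minimal sketch (illustrative names, not the
actual server code) where `fut` is produced by work submitted to a thread pool:

```python
# Minimal sketch: awaiting a wrapped concurrent.futures.Future keeps the
# event loop free to serve other requests in the meantime.
import asyncio
import concurrent.futures
import time

def slow_work() -> str:
    time.sleep(0.1)  # stand-in for blocking prediction work
    return "done"

async def handler(pool: concurrent.futures.ThreadPoolExecutor) -> str:
    fut = pool.submit(slow_work)
    # Before: fut.result() (like wait()) would block the event loop and
    # serialize all requests. Wrapping the future lets the loop keep running.
    return await asyncio.wrap_future(fut)

with concurrent.futures.ThreadPoolExecutor() as pool:
    print(asyncio.run(handler(pool)))
```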

We now have three places in the code that keep track of how many predictions
are in flight: PredictionRunner, Worker and _ChildWorker all do their own
bookkeeping. I'm not sure this is the best design, but it works.
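
To illustrate what that bookkeeping involves, here is a hypothetical sketch of
one way a single component can bound and count in-flight predictions (this is
the general pattern, not the code of any of the three classes above):

```python
# Hypothetical sketch of in-flight bookkeeping around an asyncio.Semaphore;
# not PredictionRunner, Worker or _ChildWorker, just the shared pattern.
import asyncio

class InFlightTracker:
    def __init__(self, max_concurrency: int) -> None:
        self._semaphore = asyncio.Semaphore(max_concurrency)
        self.in_flight = 0

    async def run(self, coro) -> object:
        async with self._semaphore:  # waits while at max_concurrency
            self.in_flight += 1
            try:
                return await coro
            finally:
                self.in_flight -= 1

async def main() -> None:
    tracker = InFlightTracker(max_concurrency=2)
    results = await asyncio.gather(
        *(tracker.run(asyncio.sleep(0.01, result=i)) for i in range(5))
    )
    print(results)  # [0, 1, 2, 3, 4], never more than 2 in flight

asyncio.run(main())
```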

The codebase is now an uneasy mix of threaded and asyncio code. This is evident
in the use of threading.Lock, which wouldn't be needed if we were 100% async
(and I'm not sure it's actually needed currently; I just added it to be safe).

Co-authored-by: Aron Carroll <[email protected]>
philandstuff and aron authored Dec 6, 2024
1 parent e181041 commit 3ca6205
Showing 19 changed files with 502 additions and 172 deletions.
2 changes: 1 addition & 1 deletion pkg/cli/build.go
@@ -64,7 +64,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
imageName = config.DockerImageName(projectDir)
}

- err = config.ValidateModelPythonVersion(cfg.Build.PythonVersion)
+ err = config.ValidateModelPythonVersion(cfg)
if err != nil {
return err
}
28 changes: 20 additions & 8 deletions pkg/config/config.go
@@ -30,9 +30,10 @@
// TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1)

const (
- MinimumMajorPythonVersion int = 3
- MinimumMinorPythonVersion int = 8
- MinimumMajorCudaVersion int = 11
+ MinimumMajorPythonVersion int = 3
+ MinimumMinorPythonVersion int = 8
+ MinimumMinorPythonVersionForConcurrency int = 11
+ MinimumMajorCudaVersion int = 11
)

type RunItem struct {
@@ -58,16 +59,21 @@ type Build struct {
pythonRequirementsContent []string
}

+ type Concurrency struct {
+ Max int `json:"max,omitempty" yaml:"max"`
+ }

type Example struct {
Input map[string]string `json:"input" yaml:"input"`
Output string `json:"output" yaml:"output"`
}

type Config struct {
- Build *Build `json:"build" yaml:"build"`
- Image string `json:"image,omitempty" yaml:"image"`
- Predict string `json:"predict,omitempty" yaml:"predict"`
- Train string `json:"train,omitempty" yaml:"train"`
+ Build *Build `json:"build" yaml:"build"`
+ Image string `json:"image,omitempty" yaml:"image"`
+ Predict string `json:"predict,omitempty" yaml:"predict"`
+ Train string `json:"train,omitempty" yaml:"train"`
+ Concurrency *Concurrency `json:"concurrency,omitempty" yaml:"concurrency"`
}

func DefaultConfig() *Config {
@@ -244,7 +250,9 @@ func splitPythonVersion(version string) (major int, minor int, err error) {
return major, minor, nil
}

- func ValidateModelPythonVersion(version string) error {
+ func ValidateModelPythonVersion(cfg *Config) error {
+ version := cfg.Build.PythonVersion

// we check for minimum supported here
major, minor, err := splitPythonVersion(version)
if err != nil {
@@ -255,6 +263,10 @@ func ValidateModelPythonVersion(version string) error {
return fmt.Errorf("minimum supported Python version is %d.%d. requested %s",
MinimumMajorPythonVersion, MinimumMinorPythonVersion, version)
}
+ if cfg.Concurrency != nil && cfg.Concurrency.Max > 1 && minor < MinimumMinorPythonVersionForConcurrency {
+ return fmt.Errorf("when concurrency.max is set, minimum supported Python version is %d.%d. requested %s",
+ MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency, version)
+ }
return nil
}

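For reference, a cog.yaml that opts into concurrency and passes this validation
could look like the following sketch (illustrative values; the field names come
from the yaml tags above):

```yaml
# Illustrative cog.yaml sketch: concurrency.max > 1 requires Python >= 3.11.
build:
  python_version: "3.11"
concurrency:
  max: 5
predict: "predict.py:Predictor"
```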
80 changes: 45 additions & 35 deletions pkg/config/config_test.go
@@ -13,47 +13,68 @@ import (

func TestValidateModelPythonVersion(t *testing.T) {
testCases := []struct {
- name string
- input string
- expectedErr bool
+ name string
+ pythonVersion string
+ concurrencyMax int
+ expectedErr string
}{
{
- name: "ValidVersion",
- input: "3.12",
- expectedErr: false,
+ name: "ValidVersion",
+ pythonVersion: "3.12",
},
{
- name: "MinimumVersion",
- input: "3.8",
- expectedErr: false,
+ name: "MinimumVersion",
+ pythonVersion: "3.8",
},
{
- name: "FullyQualifiedVersion",
- input: "3.12.1",
- expectedErr: false,
+ name: "MinimumVersionForConcurrency",
+ pythonVersion: "3.11",
+ concurrencyMax: 5,
},
{
- name: "InvalidFormat",
- input: "3-12",
- expectedErr: true,
+ name: "TooOldForConcurrency",
+ pythonVersion: "3.8",
+ concurrencyMax: 5,
+ expectedErr: "when concurrency.max is set, minimum supported Python version is 3.11. requested 3.8",
},
{
- name: "InvalidMissingMinor",
- input: "3",
- expectedErr: true,
+ name: "FullyQualifiedVersion",
+ pythonVersion: "3.12.1",
},
{
- name: "LessThanMinimum",
- input: "3.7",
- expectedErr: true,
+ name: "InvalidFormat",
+ pythonVersion: "3-12",
+ expectedErr: "invalid Python version format: missing minor version in 3-12",
},
+ {
+ name: "InvalidMissingMinor",
+ pythonVersion: "3",
+ expectedErr: "invalid Python version format: missing minor version in 3",
+ },
+ {
+ name: "LessThanMinimum",
+ pythonVersion: "3.7",
+ expectedErr: "minimum supported Python version is 3.8. requested 3.7",
+ },
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- err := ValidateModelPythonVersion(tc.input)
- if tc.expectedErr {
- require.Error(t, err)
+ cfg := &Config{
+ Build: &Build{
+ PythonVersion: tc.pythonVersion,
+ },
+ }
+ if tc.concurrencyMax != 0 {
+ // the Concurrency key is optional, only populate it if
+ // concurrencyMax is a non-default value
+ cfg.Concurrency = &Concurrency{
+ Max: tc.concurrencyMax,
+ }
+ }
+ err := ValidateModelPythonVersion(cfg)
+ if tc.expectedErr != "" {
+ require.ErrorContains(t, err, tc.expectedErr)
} else {
require.NoError(t, err)
}
@@ -649,17 +670,6 @@ func TestBlankBuild(t *testing.T) {
require.Equal(t, false, config.Build.GPU)
}

- func TestModelPythonVersionValidation(t *testing.T) {
- err := ValidateModelPythonVersion("3.8")
- require.NoError(t, err)
- err = ValidateModelPythonVersion("3.8.1")
- require.NoError(t, err)
- err = ValidateModelPythonVersion("3.7")
- require.Equal(t, "minimum supported Python version is 3.8. requested 3.7", err.Error())
- err = ValidateModelPythonVersion("3.7.1")
- require.Equal(t, "minimum supported Python version is 3.8. requested 3.7.1", err.Error())
- }

func TestSplitPinnedPythonRequirement(t *testing.T) {
testCases := []struct {
input string
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ tests = [
"numpy",
"pillow",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
"pytest-timeout",
"pytest-xdist",
@@ -70,6 +71,9 @@ reportUnusedExpression = "warning"
[tool.pyright.defineConstant]
PYDANTIC_V2 = true

+ [tool.pytest.ini_options]
+ asyncio_default_fixture_loop_scope = "function"

[tool.setuptools]
include-package-data = false

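The pytest-asyncio dependency and the loop-scope setting above are what allow
the test suite to define async tests. A hypothetical example of the kind of
test this enables (illustrative only, not a test from this commit):

```python
# Hypothetical async test enabled by pytest-asyncio; illustrative only.
import asyncio

import pytest

@pytest.mark.asyncio
async def test_concurrent_predictions_complete():
    async def fake_predict(i: int) -> str:
        await asyncio.sleep(0.01)  # stand-in for an async prediction
        return f"ok {i}"

    results = await asyncio.gather(*(fake_predict(i) for i in range(4)))
    assert results == [f"ok {i}" for i in range(4)]
```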
2 changes: 2 additions & 0 deletions python/cog/__init__.py
@@ -6,6 +6,7 @@
from .mimetypes_ext import install_mime_extensions
from .server.scope import current_scope, emit_metric
from .types import (
+ AsyncConcatenateIterator,
ConcatenateIterator,
ExperimentalFeatureWarning,
File,
@@ -26,6 +27,7 @@
"__version__",
"current_scope",
"emit_metric",
"AsyncConcatenateIterator",
"BaseModel",
"BasePredictor",
"ConcatenateIterator",
7 changes: 7 additions & 0 deletions python/cog/config.py
@@ -33,6 +33,7 @@
COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
COG_GPU_ENV_VAR = "COG_GPU"
+ COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

@@ -101,6 +102,12 @@ def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))

+ @property
+ @env_property(COG_MAX_CONCURRENCY_ENV_VAR)
+ def max_concurrency(self) -> int:
+ """The maximum concurrency of predictions supported by this model. Defaults to 1."""
+ return int(self._cog_config.get("concurrency", {}).get("max", 1))

def _predictor_code(
self,
module_path: str,
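The env_property decorator lets an environment variable override the cog.yaml
value. Here is a standalone sketch of that precedence, assuming
COG_MAX_CONCURRENCY wins when set (the decorator internals are not shown in
this diff):

```python
# Standalone sketch of the env-override pattern; assumes the environment
# variable takes precedence over the cog.yaml value, as env_property is
# used for above.
import os

COG_MAX_CONCURRENCY_ENV_VAR = "COG_MAX_CONCURRENCY"

def max_concurrency(cog_config: dict) -> int:
    env_value = os.environ.get(COG_MAX_CONCURRENCY_ENV_VAR)
    if env_value is not None:
        return int(env_value)
    return int(cog_config.get("concurrency", {}).get("max", 1))

print(max_concurrency({"concurrency": {"max": 5}}))  # 5, unless overridden
```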
11 changes: 5 additions & 6 deletions python/cog/server/helpers.py
@@ -33,7 +33,7 @@ def __init__(
callback: Callable[[str, str], None],
tee: bool = False,
) -> None:
- super().__init__(buffer, line_buffering=True)
+ super().__init__(buffer)

self._callback = callback
self._tee = tee
@@ -44,11 +44,10 @@ def write(self, s: str) -> int:
self._buffer.append(s)
if self._tee:
super().write(s)
- else:
- # If we're not teeing, we have to handle automatic flush on
- # newline. When `tee` is true, this is handled by the write method.
- if "\n" in s or "\r" in s:
- self.flush()
+
+ if "\n" in s or "\r" in s:
+ self.flush()

return length

def flush(self) -> None:
22 changes: 12 additions & 10 deletions python/cog/server/http.py
@@ -165,9 +165,11 @@ async def start_shutdown() -> Any:
return app

worker = make_worker(
- predictor_ref=cog_config.get_predictor_ref(mode=mode), is_async=is_async
+ predictor_ref=cog_config.get_predictor_ref(mode=mode),
+ is_async=is_async,
+ max_concurrency=cog_config.max_concurrency,
)
- runner = PredictionRunner(worker=worker)
+ runner = PredictionRunner(worker=worker, max_concurrency=cog_config.max_concurrency)

class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)):
pass
@@ -219,7 +221,7 @@ class TrainingRequest(
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
- def train(
+ async def train(
request: TrainingRequest = Body(default=None),
prefer: Optional[str] = Header(default=None),
traceparent: Optional[str] = Header(
@@ -232,7 +234,7 @@ def train(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
- return _predict(
+ return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
@@ -243,7 +245,7 @@
response_model=TrainingResponse,
response_model_exclude_unset=True,
)
- def train_idempotent(
+ async def train_idempotent(
training_id: str = Path(..., title="Training ID"),
request: TrainingRequest = Body(..., title="Training Request"),
prefer: Optional[str] = Header(default=None),
@@ -280,7 +282,7 @@ def train_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
- return _predict(
+ return await _predict(
request=request,
response_type=TrainingResponse,
respond_async=respond_async,
@@ -359,7 +361,7 @@ async def predict(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
- return _predict(
+ return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
@@ -407,13 +409,13 @@ async def predict_idempotent(
respond_async = prefer == "respond-async"

with trace_context(make_trace_context(traceparent, tracestate)):
- return _predict(
+ return await _predict(
request=request,
response_type=PredictionResponse,
respond_async=respond_async,
)

- def _predict(
+ async def _predict(
*,
request: Optional[PredictionRequest],
response_type: Type[schema.PredictionResponse],
@@ -455,7 +457,7 @@ def _predict(
)

# Otherwise, wait for the prediction to complete...
- predict_task.wait()
+ await predict_task.wait_async()

# ...and return the result.
if PYDANTIC_V2:
(11 more changed files not shown)