Skip to content

Commit

Permalink
update pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
mertalev committed Oct 6, 2024
1 parent a8d16fe commit 121cbed
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 77 deletions.
24 changes: 13 additions & 11 deletions machine-learning/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,27 @@
from socket import socket

from gunicorn.arbiter import Arbiter
from pydantic import BaseModel, BaseSettings
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
from uvicorn.workers import UvicornWorker


class PreloadModelData(BaseModel):
    """Names of ML models to load eagerly at startup.

    A value of None means "do not preload" for that model kind. Explicit
    ``= None`` defaults are required: pydantic v2 no longer treats
    ``X | None`` annotations as implicitly optional.
    """

    # CLIP model name to preload, or None to skip
    clip: str | None = None
    # facial-recognition model name to preload, or None to skip
    facial_recognition: str | None = None


class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="MACHINE_LEARNING_",
case_sensitive=False,
env_nested_delimiter="__",
protected_namespaces=("settings_",),
)

cache_folder: Path = Path("/cache")
model_ttl: int = 300
model_ttl_poll_s: int = 10
Expand All @@ -34,19 +42,13 @@ class Settings(BaseSettings):
ann_tuning_level: int = 2
preload: PreloadModelData | None = None

class Config:
env_prefix = "MACHINE_LEARNING_"
case_sensitive = False
env_nested_delimiter = "__"


class LogSettings(BaseSettings):
    """Logging configuration read from environment variables.

    Settings behavior is declared once via pydantic v2's ``model_config``;
    the legacy v1 ``class Config`` block is removed because pydantic v2
    raises a ``PydanticUserError`` when both are defined on one model.
    """

    model_config = SettingsConfigDict(case_sensitive=False)

    # log verbosity name, e.g. "info" or "debug"
    immich_log_level: str = "info"
    # when True, suppress ANSI color codes in console output
    no_color: bool = False


_clean_name = str.maketrans(":\\/", "___", ".")

Expand Down
13 changes: 6 additions & 7 deletions machine-learning/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse
from fastapi.responses import ORJSONResponse, PlainTextResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
Expand All @@ -35,7 +35,6 @@
ModelType,
PipelineRequest,
T,
TextResponse,
)

MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
Expand Down Expand Up @@ -127,14 +126,14 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
app = FastAPI(lifespan=lifespan)


@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
return {"message": "Immich ML"}
@app.get("/", response_class=ORJSONResponse)
async def root() -> MessageResponse:
return ORJSONResponse({"message": "Immich ML"})


@app.get("/ping", response_model=TextResponse)
@app.get("/ping", response_class=PlainTextResponse)
def ping() -> str:
return "pong"
return PlainTextResponse("pong")


@app.post("/predict", dependencies=[Depends(update_state)])
Expand Down
10 changes: 3 additions & 7 deletions machine-learning/app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from enum import Enum
from typing import Any, Literal, Protocol, TypedDict, TypeGuard, TypeVar
from typing import Any, Literal, Protocol, TypeGuard, TypeVar

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel
from typing_extensions import TypedDict


class StrEnum(str, Enum):
Expand All @@ -13,11 +13,7 @@ def __str__(self) -> str:
return self.value


class TextResponse(BaseModel):
__root__: str


class MessageResponse(TypedDict):
    """Shape of the JSON body returned by the root endpoint.

    A TypedDict (not a pydantic model) so it can be passed straight to
    ORJSONResponse without a serialization round-trip. The stray
    ``class MessageResponse(BaseModel):`` header left over from the diff
    is dropped — an empty class body followed by a second class statement
    is a syntax error.
    """

    # human-readable service identification string
    message: str


Expand Down
17 changes: 16 additions & 1 deletion machine-learning/app/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,26 @@ async def test_falls_back_to_onnx_if_other_format_does_not_exist(
mock_model.model_format = ModelFormat.ONNX


def test_root_endpoint(deployed_app: TestClient) -> None:
    """The root route answers 200 with the service-identification JSON."""
    res = deployed_app.get("http://localhost:3003")

    assert res.status_code == 200
    assert res.json() == {"message": "Immich ML"}


def test_ping_endpoint(deployed_app: TestClient) -> None:
    """The /ping liveness route answers 200 with plain-text "pong"."""
    res = deployed_app.get("http://localhost:3003/ping")

    assert res.status_code == 200
    assert res.text == "pong"


@pytest.mark.skipif(
not settings.test_full,
reason="More time-consuming since it deploys the app and loads models.",
)
class TestEndpoints:
class TestPredictionEndpoints:
def test_clip_image_endpoint(
self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
) -> None:
Expand Down
Loading

0 comments on commit 121cbed

Please sign in to comment.