fix: hotfix create_with_completion failing for AdapterBase, ParallelBase, IterableBase and PartialBase #1103

Status: Open. Wants to merge 5 commits into main. Showing changes from all commits.
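
For context, the failure mode this PR targets can be sketched as follows. The snippet is illustrative rather than taken from the PR: `User`, the model id, and the prompt are assumptions, while `from_openai` and `create_with_completion` are the existing instructor entry points touched by this diff.

```python
# Illustrative repro sketch (not from the PR). Before this change,
# create_with_completion fetched the raw completion via model._raw_response
# on the returned object, an attribute that the generator-shaped results of
# Iterable[...]/Partial[...] and adapter/parallel outputs do not carry.
from typing import Iterable

import instructor
from openai import OpenAI
from pydantic import BaseModel


class User(BaseModel):  # assumed example model
    name: str
    age: int


client = instructor.from_openai(OpenAI())

# The call this PR is meant to make work again:
users, completion = client.create_with_completion(
    model="gpt-4o-mini",
    response_model=Iterable[User],
    messages=[{"role": "user", "content": "Extract: Anna is 30, Ben is 25."}],
)
```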
42 changes: 29 additions & 13 deletions instructor/client.py
@@ -11,6 +11,7 @@
overload,
Union,
Literal,
Generic,
Any,
)
from collections.abc import Generator, Iterable, Awaitable, AsyncGenerator
@@ -23,6 +24,14 @@
T = TypeVar("T", bound=Union[BaseModel, "Iterable[Any]", "Partial[Any]"])


class Creation(Generic[T]):
raw: Any # should be uniform completion type
processed: T

def __init__(self, raw: Any, processed: T) -> None:
self.raw, self.processed = raw, processed


class Instructor:
client: Any | None
create_fn: Callable[..., Any]
@@ -34,7 +43,7 @@ class Instructor:
def __init__(
self,
client: Any | None,
create: Callable[..., Any],
create: Callable[..., Creation[T]],
mode: instructor.Mode = instructor.Mode.TOOLS,
provider: Provider = Provider.OPENAI,
hooks: Hooks | None = None,
@@ -169,7 +178,7 @@ def create(
) -> T | Any | Awaitable[T] | Awaitable[Any]:
kwargs = self.handle_kwargs(kwargs)

return self.create_fn(
creation: Creation[T] = self.create_fn(
response_model=response_model,
messages=messages,
max_retries=max_retries,
Expand All @@ -179,6 +188,7 @@ def create(
hooks=self.hooks,
**kwargs,
)
return creation.processed

@overload
def create_partial(
@@ -219,7 +229,7 @@ def create_partial(
kwargs = self.handle_kwargs(kwargs)

response_model = instructor.Partial[response_model] # type: ignore
return self.create_fn(
creation: Creation[Generator[T, None, None]] = self.create_fn(
messages=messages,
response_model=response_model,
max_retries=max_retries,
Expand All @@ -229,6 +239,7 @@ def create_partial(
hooks=self.hooks,
**kwargs,
)
return creation.processed

@overload
def create_iterable(
@@ -268,7 +279,7 @@ def create_iterable(
kwargs = self.handle_kwargs(kwargs)

response_model = Iterable[response_model] # type: ignore
return self.create_fn(
creation: Creation[Generator[T, None, None]] = self.create_fn(
messages=messages,
response_model=response_model,
max_retries=max_retries,
Expand All @@ -278,6 +289,7 @@ def create_iterable(
hooks=self.hooks,
**kwargs,
)
return creation.processed

@overload
def create_with_completion(
@@ -314,7 +326,7 @@ def create_with_completion(
**kwargs: Any,
) -> tuple[T, Any] | Awaitable[tuple[T, Any]]:
kwargs = self.handle_kwargs(kwargs)
model = self.create_fn(
creation: Creation[T] = self.create_fn(
messages=messages,
response_model=response_model,
max_retries=max_retries,
Expand All @@ -324,7 +336,7 @@ def create_with_completion(
hooks=self.hooks,
**kwargs,
)
return model, model._raw_response
return creation.processed, creation.raw

def handle_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
"""
@@ -356,7 +368,7 @@ class AsyncInstructor(Instructor):
def __init__(
self,
client: Any | None,
create: Callable[..., Any],
create: Callable[..., Creation[T]],
mode: instructor.Mode = instructor.Mode.TOOLS,
provider: Provider = Provider.OPENAI,
hooks: Hooks | None = None,
@@ -380,7 +392,7 @@ async def create(
**kwargs: Any,
) -> T | Any:
kwargs = self.handle_kwargs(kwargs)
return await self.create_fn(
creation: Creation[T] = await self.create_fn(
response_model=response_model,
validation_context=validation_context,
context=context,
Expand All @@ -390,6 +402,7 @@ async def create(
hooks=self.hooks,
**kwargs,
)
return creation.processed

async def create_partial(
self,
@@ -403,7 +416,7 @@ async def create_partial(
) -> AsyncGenerator[T, None]:
kwargs = self.handle_kwargs(kwargs)
kwargs["stream"] = True
async for item in await self.create_fn(
creation: Creation[AsyncGenerator[T, None]] = await self.create_fn( # type: ignore
response_model=instructor.Partial[response_model], # type: ignore
validation_context=validation_context,
context=context,
Expand All @@ -412,7 +425,8 @@ async def create_partial(
strict=strict,
hooks=self.hooks,
**kwargs,
):
)
async for item in creation.processed:
yield item

async def create_iterable(
@@ -427,7 +441,7 @@ async def create_iterable(
) -> AsyncGenerator[T, None]:
kwargs = self.handle_kwargs(kwargs)
kwargs["stream"] = True
async for item in await self.create_fn(
async for creation in await self.create_fn(
response_model=Iterable[response_model],
validation_context=validation_context,
context=context,
Expand All @@ -437,6 +451,8 @@ async def create_iterable(
hooks=self.hooks,
**kwargs,
):
creation: Creation[T]
item = creation.processed
yield item

async def create_with_completion(
@@ -450,7 +466,7 @@ async def create_with_completion(
**kwargs: Any,
) -> tuple[T, Any]:
kwargs = self.handle_kwargs(kwargs)
response = await self.create_fn(
creation: Creation[T] = await self.create_fn(
response_model=response_model,
validation_context=validation_context,
context=context,
Expand All @@ -460,7 +476,7 @@ async def create_with_completion(
hooks=self.hooks,
**kwargs,
)
return response, response._raw_response
return creation.processed, creation.raw


@overload
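
The design choice behind the client.py changes above, restated as a standalone sketch: rather than attaching the raw completion to whatever `process_response` returns, `create_fn` now hands back a small `Creation[T]` container, and each public method unwraps `.processed` (with `create_with_completion` also returning `.raw`). The snippet below is a simplified stand-in, not the class from the diff.

```python
# Minimal standalone sketch of the wrapper pattern introduced above (simplified).
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class Creation(Generic[T]):
    """Pairs the provider's raw completion with the parsed result."""

    def __init__(self, raw: Any, processed: T) -> None:
        self.raw = raw              # untouched completion (usage, id, ...)
        self.processed = processed  # validated model, generator, or adapter output


# Both halves now travel together, so nothing has to be smuggled through a
# _raw_response attribute that generators and adapter outputs never had.
creation = Creation(raw={"id": "chatcmpl-123"}, processed=["parsed", "items"])
model, raw = creation.processed, creation.raw  # what create_with_completion returns
```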
9 changes: 5 additions & 4 deletions instructor/patch.py
@@ -13,6 +13,7 @@
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel

from instructor.client import Creation
from instructor.process_response import handle_response_model
from instructor.retry import retry_async, retry_sync
from instructor.utils import is_async
@@ -38,7 +39,7 @@ def __call__(
max_retries: int = 1,
*args: Any,
**kwargs: Any,
) -> T_Model: ...
) -> Creation[T_Model]: ...


class AsyncInstructorChatCompletionCreate(Protocol):
@@ -50,7 +51,7 @@ async def __call__(
max_retries: int = 1,
*args: Any,
**kwargs: Any,
) -> T_Model: ...
) -> Creation[T_Model]: ...


def handle_context(
Expand Down Expand Up @@ -145,7 +146,7 @@ async def new_create_async(
hooks: Hooks | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
) -> Creation[T_Model]:
context = handle_context(context, validation_context)

response_model, new_kwargs = handle_response_model(
@@ -176,7 +177,7 @@ def new_create_sync(
hooks: Hooks | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
) -> Creation[T_Model]:
context = handle_context(context, validation_context)

response_model, new_kwargs = handle_response_model(
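
A side effect of the patch.py protocol changes above: anything that stands in for the patched create callable (for example a stub in tests or a custom provider shim) is now expected to return a `Creation` rather than a bare model. A hedged sketch of a conforming stub, with a simplified parameter list:

```python
# Sketch only: a stub create callable shaped like the patched sync protocol above.
from typing import Any, TypeVar

from pydantic import BaseModel

from instructor.client import Creation  # Creation is defined in instructor/client.py by this PR

T_Model = TypeVar("T_Model", bound=BaseModel)


def stub_create(
    response_model: type[T_Model],
    max_retries: int = 1,
    *args: Any,
    **kwargs: Any,
) -> Creation[T_Model]:
    raw = {"id": "chatcmpl-stub", "choices": []}  # stand-in for a provider completion
    processed = response_model.model_construct()  # unvalidated instance; sketch only
    return Creation(raw=raw, processed=processed)
```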
11 changes: 7 additions & 4 deletions instructor/retry.py
@@ -6,6 +6,7 @@
from json import JSONDecodeError
from typing import Any, Callable, TypeVar

from instructor.client import Creation
from instructor.exceptions import InstructorRetryException
from instructor.hooks import Hooks
from instructor.mode import Mode
@@ -104,7 +105,7 @@ def retry_sync(
strict: bool | None = None,
mode: Mode = Mode.TOOLS,
hooks: Hooks | None = None,
) -> T_Model | None:
) -> Creation[T_Model] | None:
"""
Retry a synchronous function upon specified exceptions.

@@ -141,14 +142,15 @@
response = update_total_usage(
response=response, total_usage=total_usage
)
return process_response( # type: ignore
processed = process_response( # type: ignore
response=response,
response_model=response_model,
validation_context=context,
strict=strict,
mode=mode,
stream=kwargs.get("stream", False),
)
return Creation(processed=processed, raw=response)
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Parse error: {e}")
hooks.emit_parse_error(e)
@@ -184,7 +186,7 @@ async def retry_async(
strict: bool | None = None,
mode: Mode = Mode.TOOLS,
hooks: Hooks | None = None,
) -> T_Model | None:
) -> Creation[T_Model] | None:
"""
Retry an asynchronous function upon specified exceptions.

@@ -222,14 +224,15 @@
response=response, total_usage=total_usage
)

return await process_response_async(
processed = await process_response_async(
response=response,
response_model=response_model,
validation_context=context,
strict=strict,
mode=mode,
stream=kwargs.get("stream", False),
)
return Creation(processed=processed, raw=response)
except (ValidationError, JSONDecodeError, AsyncValidationError) as e:
logger.debug(f"Parse error: {e}")
hooks.emit_parse_error(e)
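
The retry helpers above are where the pairing actually happens: the provider response, with usage totals accumulated across retries by `update_total_usage`, becomes `Creation.raw`, and the output of `process_response`/`process_response_async` becomes `Creation.processed`. That contract now holds uniformly, including for the AdapterBase/ParallelBase/IterableBase/PartialBase cases named in the title. A plain-model sketch, continuing the first snippet (same assumed `client` and `User`):

```python
# Not part of this PR: end-to-end sketch of the processed/raw contract,
# reusing the assumed client and User from the first snippet.
user, completion = client.create_with_completion(
    model="gpt-4o-mini",
    response_model=User,
    messages=[{"role": "user", "content": "Anna is 30 years old."}],
)
print(user)              # Creation.processed: a validated User instance
print(completion.usage)  # Creation.raw: the provider completion, usage merged across retries
```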