From 1f30dce2c550d109d6a2cf833b50e387a0174eb0 Mon Sep 17 00:00:00 2001 From: ivanbelenky Date: Mon, 21 Oct 2024 02:04:03 -0300 Subject: [PATCH] first commit for creation wrapper --- instructor/client.py | 42 +++++++++++++++++++++++++++++------------- instructor/patch.py | 9 +++++---- instructor/retry.py | 11 +++++++---- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/instructor/client.py b/instructor/client.py index fa452c884..9aa2dff56 100644 --- a/instructor/client.py +++ b/instructor/client.py @@ -11,6 +11,7 @@ overload, Union, Literal, + Generic, Any, ) from collections.abc import Generator, Iterable, Awaitable, AsyncGenerator @@ -23,6 +24,14 @@ T = TypeVar("T", bound=Union[BaseModel, "Iterable[Any]", "Partial[Any]"]) +class Creation(Generic[T]): + raw: Any # should be uniform completion type + processed: T + + def __init__(self, raw: Any, processed: T) -> None: + self.raw, self.processed = raw, processed + + class Instructor: client: Any | None create_fn: Callable[..., Any] @@ -34,7 +43,7 @@ class Instructor: def __init__( self, client: Any | None, - create: Callable[..., Any], + create: Callable[..., Creation[T]], mode: instructor.Mode = instructor.Mode.TOOLS, provider: Provider = Provider.OPENAI, hooks: Hooks | None = None, @@ -169,7 +178,7 @@ def create( ) -> T | Any | Awaitable[T] | Awaitable[Any]: kwargs = self.handle_kwargs(kwargs) - return self.create_fn( + creation: Creation[T] = self.create_fn( response_model=response_model, messages=messages, max_retries=max_retries, @@ -179,6 +188,7 @@ def create( hooks=self.hooks, **kwargs, ) + return creation.processed @overload def create_partial( @@ -219,7 +229,7 @@ def create_partial( kwargs = self.handle_kwargs(kwargs) response_model = instructor.Partial[response_model] # type: ignore - return self.create_fn( + creation: Creation[Generator[T, None, None]] = self.create_fn( messages=messages, response_model=response_model, max_retries=max_retries, @@ -229,6 +239,7 @@ def create_partial( hooks=self.hooks, **kwargs, ) + return creation.processed @overload def create_iterable( @@ -268,7 +279,7 @@ def create_iterable( kwargs = self.handle_kwargs(kwargs) response_model = Iterable[response_model] # type: ignore - return self.create_fn( + creation: Creation[Generator[T, None, None]] = self.create_fn( messages=messages, response_model=response_model, max_retries=max_retries, @@ -278,6 +289,7 @@ def create_iterable( hooks=self.hooks, **kwargs, ) + return creation.processed @overload def create_with_completion( @@ -314,7 +326,7 @@ def create_with_completion( **kwargs: Any, ) -> tuple[T, Any] | Awaitable[tuple[T, Any]]: kwargs = self.handle_kwargs(kwargs) - model = self.create_fn( + creation: Creation[T] = self.create_fn( messages=messages, response_model=response_model, max_retries=max_retries, @@ -324,7 +336,7 @@ def create_with_completion( hooks=self.hooks, **kwargs, ) - return model, model._raw_response + return creation.processed, creation.raw def handle_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: """ @@ -356,7 +368,7 @@ class AsyncInstructor(Instructor): def __init__( self, client: Any | None, - create: Callable[..., Any], + create: Callable[..., Creation[T]], mode: instructor.Mode = instructor.Mode.TOOLS, provider: Provider = Provider.OPENAI, hooks: Hooks | None = None, @@ -380,7 +392,7 @@ async def create( **kwargs: Any, ) -> T | Any: kwargs = self.handle_kwargs(kwargs) - return await self.create_fn( + creation: Creation[T] = await self.create_fn( response_model=response_model, validation_context=validation_context, context=context, @@ -390,6 +402,7 @@ async def create( hooks=self.hooks, **kwargs, ) + return creation.processed async def create_partial( self, @@ -403,7 +416,7 @@ async def create_partial( ) -> AsyncGenerator[T, None]: kwargs = self.handle_kwargs(kwargs) kwargs["stream"] = True - async for item in await self.create_fn( + creation: Creation[AsyncGenerator[T, None]] = await self.create_fn( # type: ignore response_model=instructor.Partial[response_model], # type: ignore validation_context=validation_context, context=context, @@ -412,7 +425,8 @@ async def create_partial( strict=strict, hooks=self.hooks, **kwargs, - ): + ) + async for item in creation.processed: yield item async def create_iterable( @@ -427,7 +441,7 @@ async def create_iterable( ) -> AsyncGenerator[T, None]: kwargs = self.handle_kwargs(kwargs) kwargs["stream"] = True - async for item in await self.create_fn( + async for creation in await self.create_fn( response_model=Iterable[response_model], validation_context=validation_context, context=context, @@ -437,6 +451,8 @@ async def create_iterable( hooks=self.hooks, **kwargs, ): + creation: Creation[T] + item = creation.processed yield item async def create_with_completion( @@ -450,7 +466,7 @@ async def create_with_completion( **kwargs: Any, ) -> tuple[T, Any]: kwargs = self.handle_kwargs(kwargs) - response = await self.create_fn( + creation: Creation[T] = await self.create_fn( response_model=response_model, validation_context=validation_context, context=context, @@ -460,7 +476,7 @@ async def create_with_completion( hooks=self.hooks, **kwargs, ) - return response, response._raw_response + return creation.processed, creation.raw @overload diff --git a/instructor/patch.py b/instructor/patch.py index 2d1f340e9..0c1de1148 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -13,6 +13,7 @@ from openai import AsyncOpenAI, OpenAI from pydantic import BaseModel +from instructor.client import Creation from instructor.process_response import handle_response_model from instructor.retry import retry_async, retry_sync from instructor.utils import is_async @@ -38,7 +39,7 @@ def __call__( max_retries: int = 1, *args: Any, **kwargs: Any, - ) -> T_Model: ... + ) -> Creation[T_Model]: ... class AsyncInstructorChatCompletionCreate(Protocol): @@ -50,7 +51,7 @@ async def __call__( max_retries: int = 1, *args: Any, **kwargs: Any, - ) -> T_Model: ... + ) -> Creation[T_Model]: ... def handle_context( @@ -145,7 +146,7 @@ async def new_create_async( hooks: Hooks | None = None, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: + ) -> Creation[T_Model]: context = handle_context(context, validation_context) response_model, new_kwargs = handle_response_model( @@ -176,7 +177,7 @@ def new_create_sync( hooks: Hooks | None = None, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs, - ) -> T_Model: + ) -> Creation[T_Model]: context = handle_context(context, validation_context) response_model, new_kwargs = handle_response_model( diff --git a/instructor/retry.py b/instructor/retry.py index ffc41194d..b98c05548 100644 --- a/instructor/retry.py +++ b/instructor/retry.py @@ -6,6 +6,7 @@ from json import JSONDecodeError from typing import Any, Callable, TypeVar +from instructor.client import Creation from instructor.exceptions import InstructorRetryException from instructor.hooks import Hooks from instructor.mode import Mode @@ -104,7 +105,7 @@ def retry_sync( strict: bool | None = None, mode: Mode = Mode.TOOLS, hooks: Hooks | None = None, -) -> T_Model | None: +) -> Creation[T_Model] | None: """ Retry a synchronous function upon specified exceptions. @@ -141,7 +142,7 @@ def retry_sync( response = update_total_usage( response=response, total_usage=total_usage ) - return process_response( # type: ignore + processed = process_response( # type: ignore response=response, response_model=response_model, validation_context=context, @@ -149,6 +150,7 @@ def retry_sync( mode=mode, stream=kwargs.get("stream", False), ) + return Creation(processed=processed, raw=response) except (ValidationError, JSONDecodeError) as e: logger.debug(f"Parse error: {e}") hooks.emit_parse_error(e) @@ -184,7 +186,7 @@ async def retry_async( strict: bool | None = None, mode: Mode = Mode.TOOLS, hooks: Hooks | None = None, -) -> T_Model | None: +) -> Creation[T_Model] | None: """ Retry an asynchronous function upon specified exceptions. @@ -222,7 +224,7 @@ async def retry_async( response=response, total_usage=total_usage ) - return await process_response_async( + processed = await process_response_async( response=response, response_model=response_model, validation_context=context, @@ -230,6 +232,7 @@ async def retry_async( mode=mode, stream=kwargs.get("stream", False), ) + return Creation(processed=processed, raw=response) except (ValidationError, JSONDecodeError, AsyncValidationError) as e: logger.debug(f"Parse error: {e}") hooks.emit_parse_error(e)