diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 3880d35f8..567178432 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -1,21 +1,19 @@
 import functools
+from .base_lm import BaseLM
 import logging
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional
 
+import litellm
 import ujson
-from litellm import Router
-from litellm.router import RetryPolicy
 
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
 from dspy.utils.callback import BaseCallback, with_callbacks
 
-from .base_lm import BaseLM
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +32,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 8,
+        num_retries: int = 3,
         **kwargs,
     ):
         """
@@ -165,55 +163,6 @@ def copy(self, **kwargs):
         return new_instance
 
 
-@dataclass(frozen=True)
-class _ProviderAPIConfig:
-    """
-    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
-    """
-
-    api_key: Optional[str]
-    api_base: Optional[str]
-    api_version: Optional[str]
-    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
-    # For all other providers, this field is empty
-    azure_ad_token: Optional[str]
-
-
-def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
-    """
-    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
-    provider corresponding to the given model.
-
-    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
-    the input dictionary.
-    """
-    provider = _get_provider(model)
-    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
-    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
-    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
-    if "azure" in provider:
-        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
-    else:
-        azure_ad_token = None
-    return _ProviderAPIConfig(
-        api_key=api_key,
-        api_base=api_base,
-        api_version=api_version,
-        azure_ad_token=azure_ad_token,
-    )
-
-
-def _get_provider(model: str) -> str:
-    """
-    Extract the provider name from the model string of the format "<provider>/<model>",
-    e.g. "openai/gpt-4".
-
-    TODO: Not all the models are in the format of "provider/model"
-    """
-    model = model.split("/", 1)
-    return model[0] if len(model) > 1 else "openai"
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
@@ -225,69 +174,13 @@ def cached_litellm_completion(request, num_retries: int):
 
 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
-    return router.completion(
+    return litellm.completion(
+        num_retries=num_retries,
         cache=cache,
         **kwargs,
     )
 
 
-@functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
-    """
-    Get a LiteLLM router for the given model with the specified number of retries
-    for transient errors.
-
-    Args:
-        model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
-        num_retries: The number of times to retry a request if it fails transiently due to
-                     network error, rate limiting, etc. Requests are retried with exponential
-                     backoff.
-        api_config: The API configurations (keys, base URL, etc.) for the provider
-                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
-
-    Returns:
-        A LiteLLM router instance that can be used to query the given model.
-    """
-    retry_policy = RetryPolicy(
-        TimeoutErrorRetries=num_retries,
-        RateLimitErrorRetries=num_retries,
-        InternalServerErrorRetries=num_retries,
-        # We don't retry on errors that are unlikely to be transient
-        # (e.g. bad request, invalid auth credentials)
-        BadRequestErrorRetries=0,
-        AuthenticationErrorRetries=0,
-        ContentPolicyViolationErrorRetries=0,
-    )
-
-    # LiteLLM routers must specify a `model_list`, which maps model names passed
-    # to `completions()` into actual LiteLLM model names. For our purposes, the
-    # model name is the same as the LiteLLM model name, so we add a single
-    # entry to the `model_list` that maps the model name to itself
-    litellm_params = {
-        "model": model,
-    }
-    if api_config.api_key is not None:
-        litellm_params["api_key"] = api_config.api_key
-    if api_config.api_base is not None:
-        litellm_params["api_base"] = api_config.api_base
-    if api_config.api_version is not None:
-        litellm_params["api_version"] = api_config.api_version
-    if api_config.azure_ad_token is not None:
-        litellm_params["azure_ad_token"] = api_config.azure_ad_token
-    model_list = [
-        {
-            "model_name": model,
-            "litellm_params": litellm_params,
-        }
-    ]
-    return Router(
-        model_list=model_list,
-        retry_policy=retry_policy,
-    )
-
-
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request, num_retries: int):
     return litellm_text_completion(
@@ -299,18 +192,25 @@ def cached_litellm_text_completion(request, num_retries: int):
 
 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    model = kwargs.pop("model")
-    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
-    model_name = model.split("/", 1)[-1]
-    text_completion_model_name = f"text-completion-openai/{model_name}"
+
+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = kwargs.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the kwargs, or from the environment.
+    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
 
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
-    return router.text_completion(
+    return litellm.text_completion(
         cache=cache,
-        model=text_completion_model_name,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
         **kwargs,
     )
diff --git a/poetry.lock b/poetry.lock
index a8ee6afa8..54ccaf2df 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -2411,13 +2411,13 @@ tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pyt
 
 [[package]]
 name = "litellm"
-version = "1.49.1"
+version = "1.51.0"
 description = "Library to easily interface with LLM API providers"
 optional = false
 python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
 files = [
-    {file = "litellm-1.49.1-py3-none-any.whl", hash = "sha256:2ba6689fe4ea3b0d69f56f2843caff6422497489e6252943b13ef1463f016728"},
-    {file = "litellm-1.49.1.tar.gz", hash = "sha256:f51450ad823c8bdf057017009ae8bcce1a2810690b2f0d9dcdaff04ddc68209a"},
+    {file = "litellm-1.51.0-py3-none-any.whl", hash = "sha256:0b2c20d116834166c8440e5698d7d927dbcc78fcaa08ce0c5cbea2d0de55ec6c"},
+    {file = "litellm-1.51.0.tar.gz", hash = "sha256:8bf648677ee145a8fe5054a2e3f3a34895b9ab65a6015e4b94efca7ef406f466"},
 ]
 
 [package.dependencies]
@@ -2426,7 +2426,7 @@ click = "*"
 importlib-metadata = ">=6.8.0"
 jinja2 = ">=3.1.2,<4.0.0"
 jsonschema = ">=4.22.0,<5.0.0"
-openai = ">=1.51.0"
+openai = ">=1.52.0"
 pydantic = ">=2.0.0,<3.0.0"
 python-dotenv = ">=0.2.0"
 requests = ">=2.31.0,<3.0.0"
@@ -7862,4 +7862,4 @@ weaviate = ["weaviate-client"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.13"
-content-hash = "4c0c0eda720efe7fbc74f58ade43fcf01f61ee8295154dd74a1a70d6ddc30280"
+content-hash = "92c91613bb51ec6d493672baf6c0d509ebb16ec0c7dd8d23f9cfd6e4654972cf"
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index 7b3a02481..da4029e95 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -1,145 +1,145 @@
 from unittest import mock
 
 import pytest
-from litellm.router import RetryPolicy
-
-from dspy.clients.lm import LM, _get_litellm_router
-
-
-@pytest.mark.parametrize("keys_in_env_vars", [True, False])
-def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch):
-    model_name = "openai/gpt4o"
-    num_retries = 17
-    temperature = 0.5
-    max_tokens = 100
-    prompt = "Hello, world!"
-    api_version = "2024-02-01"
-    api_key = "apikey"
-
-    lm_kwargs = {
-        "model": model_name,
-        "model_type": "chat",
-        "num_retries": num_retries,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-    }
-    if keys_in_env_vars:
-        api_base = "http://testfromenv.com"
-        monkeypatch.setenv("OPENAI_API_KEY", api_key)
-        monkeypatch.setenv("OPENAI_API_BASE", api_base)
-        monkeypatch.setenv("OPENAI_API_VERSION", api_version)
-    else:
-        api_base = "http://test.com"
-        lm_kwargs["api_key"] = api_key
-        lm_kwargs["api_base"] = api_base
-        lm_kwargs["api_version"] = api_version
-
-    lm = LM(**lm_kwargs)
-
-    MockRouter = mock.MagicMock()
-    mock_completion = mock.MagicMock()
-    MockRouter.completion = mock_completion
-
-    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
-        lm(prompt=prompt)
-
-    MockRouterConstructor.assert_called_once_with(
-        model_list=[
-            {
-                "model_name": model_name,
-                "litellm_params": {
-                    "model": model_name,
-                    "api_key": api_key,
-                    "api_base": api_base,
-                    "api_version": api_version,
-                },
-            }
-        ],
-        retry_policy=RetryPolicy(
-            TimeoutErrorRetries=num_retries,
-            RateLimitErrorRetries=num_retries,
-            InternalServerErrorRetries=num_retries,
-            BadRequestErrorRetries=0,
-            AuthenticationErrorRetries=0,
-            ContentPolicyViolationErrorRetries=0,
-        ),
-    )
-    mock_completion.assert_called_once_with(
-        model=model_name,
-        messages=[{"role": "user", "content": prompt}],
-        temperature=temperature,
-        max_tokens=max_tokens,
-        cache=mock.ANY,
-    )
-
-
-@pytest.mark.parametrize("keys_in_env_vars", [True, False])
-def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch):
-    model_name = "azure/gpt-3.5-turbo"
-    expected_model = "text-completion-openai/" + model_name.split("/")[-1]
-    num_retries = 17
-    temperature = 0.5
-    max_tokens = 100
-    prompt = "Hello, world!"
-    api_version = "2024-02-01"
-    api_key = "apikey"
-    azure_ad_token = "adtoken"
-
-    lm_kwargs = {
-        "model": model_name,
-        "model_type": "text",
-        "num_retries": num_retries,
-        "temperature": temperature,
-        "max_tokens": max_tokens,
-    }
-    if keys_in_env_vars:
-        api_base = "http://testfromenv.com"
-        monkeypatch.setenv("AZURE_API_KEY", api_key)
-        monkeypatch.setenv("AZURE_API_BASE", api_base)
-        monkeypatch.setenv("AZURE_API_VERSION", api_version)
-        monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token)
-    else:
-        api_base = "http://test.com"
-        lm_kwargs["api_key"] = api_key
-        lm_kwargs["api_base"] = api_base
-        lm_kwargs["api_version"] = api_version
-        lm_kwargs["azure_ad_token"] = azure_ad_token
-
-    lm = LM(**lm_kwargs)
-
-    MockRouter = mock.MagicMock()
-    mock_text_completion = mock.MagicMock()
-    MockRouter.text_completion = mock_text_completion
-
-    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
-        lm(prompt=prompt)
-
-    MockRouterConstructor.assert_called_once_with(
-        model_list=[
-            {
-                "model_name": expected_model,
-                "litellm_params": {
-                    "model": expected_model,
-                    "api_key": api_key,
-                    "api_base": api_base,
-                    "api_version": api_version,
-                    "azure_ad_token": azure_ad_token,
-                },
-            }
-        ],
-        retry_policy=RetryPolicy(
-            TimeoutErrorRetries=num_retries,
-            RateLimitErrorRetries=num_retries,
-            InternalServerErrorRetries=num_retries,
-            BadRequestErrorRetries=0,
-            AuthenticationErrorRetries=0,
-            ContentPolicyViolationErrorRetries=0,
-        ),
-    )
-    mock_text_completion.assert_called_once_with(
-        model=expected_model,
-        prompt=prompt + "\n\nBEGIN RESPONSE:",
-        temperature=temperature,
-        max_tokens=max_tokens,
-        cache=mock.ANY,
-    )
+# from litellm.router import RetryPolicy
+
+# from dspy.clients.lm import LM, _get_litellm_router
+
+
+# @pytest.mark.parametrize("keys_in_env_vars", [True, False])
+# def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch):
+#     model_name = "openai/gpt4o"
+#     num_retries = 17
+#     temperature = 0.5
+#     max_tokens = 100
+#     prompt = "Hello, world!"
+#     api_version = "2024-02-01"
+#     api_key = "apikey"
+
+#     lm_kwargs = {
+#         "model": model_name,
+#         "model_type": "chat",
+#         "num_retries": num_retries,
+#         "temperature": temperature,
+#         "max_tokens": max_tokens,
+#     }
+#     if keys_in_env_vars:
+#         api_base = "http://testfromenv.com"
+#         monkeypatch.setenv("OPENAI_API_KEY", api_key)
+#         monkeypatch.setenv("OPENAI_API_BASE", api_base)
+#         monkeypatch.setenv("OPENAI_API_VERSION", api_version)
+#     else:
+#         api_base = "http://test.com"
+#         lm_kwargs["api_key"] = api_key
+#         lm_kwargs["api_base"] = api_base
+#         lm_kwargs["api_version"] = api_version
+
+#     lm = LM(**lm_kwargs)
+
+#     MockRouter = mock.MagicMock()
+#     mock_completion = mock.MagicMock()
+#     MockRouter.completion = mock_completion
+
+#     with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+#         lm(prompt=prompt)
+
+#     MockRouterConstructor.assert_called_once_with(
+#         model_list=[
+#             {
+#                 "model_name": model_name,
+#                 "litellm_params": {
+#                     "model": model_name,
+#                     "api_key": api_key,
+#                     "api_base": api_base,
+#                     "api_version": api_version,
+#                 },
+#             }
+#         ],
+#         retry_policy=RetryPolicy(
+#             TimeoutErrorRetries=num_retries,
+#             RateLimitErrorRetries=num_retries,
+#             InternalServerErrorRetries=num_retries,
+#             BadRequestErrorRetries=0,
+#             AuthenticationErrorRetries=0,
+#             ContentPolicyViolationErrorRetries=0,
+#         ),
+#     )
+#     mock_completion.assert_called_once_with(
+#         model=model_name,
+#         messages=[{"role": "user", "content": prompt}],
+#         temperature=temperature,
+#         max_tokens=max_tokens,
+#         cache=mock.ANY,
+#     )
+
+
+# @pytest.mark.parametrize("keys_in_env_vars", [True, False])
+# def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch):
+#     model_name = "azure/gpt-3.5-turbo"
+#     expected_model = "text-completion-openai/" + model_name.split("/")[-1]
+#     num_retries = 17
+#     temperature = 0.5
+#     max_tokens = 100
+#     prompt = "Hello, world!"
+#     api_version = "2024-02-01"
+#     api_key = "apikey"
+#     azure_ad_token = "adtoken"
+
+#     lm_kwargs = {
+#         "model": model_name,
+#         "model_type": "text",
+#         "num_retries": num_retries,
+#         "temperature": temperature,
+#         "max_tokens": max_tokens,
+#     }
+#     if keys_in_env_vars:
+#         api_base = "http://testfromenv.com"
+#         monkeypatch.setenv("AZURE_API_KEY", api_key)
+#         monkeypatch.setenv("AZURE_API_BASE", api_base)
+#         monkeypatch.setenv("AZURE_API_VERSION", api_version)
+#         monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token)
+#     else:
+#         api_base = "http://test.com"
+#         lm_kwargs["api_key"] = api_key
+#         lm_kwargs["api_base"] = api_base
+#         lm_kwargs["api_version"] = api_version
+#         lm_kwargs["azure_ad_token"] = azure_ad_token
+
+#     lm = LM(**lm_kwargs)
+
+#     MockRouter = mock.MagicMock()
+#     mock_text_completion = mock.MagicMock()
+#     MockRouter.text_completion = mock_text_completion
+
+#     with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+#         lm(prompt=prompt)
+
+#     MockRouterConstructor.assert_called_once_with(
+#         model_list=[
+#             {
+#                 "model_name": expected_model,
+#                 "litellm_params": {
+#                     "model": expected_model,
+#                     "api_key": api_key,
+#                     "api_base": api_base,
+#                     "api_version": api_version,
+#                     "azure_ad_token": azure_ad_token,
+#                 },
+#             }
+#         ],
+#         retry_policy=RetryPolicy(
+#             TimeoutErrorRetries=num_retries,
+#             RateLimitErrorRetries=num_retries,
+#             InternalServerErrorRetries=num_retries,
+#             BadRequestErrorRetries=0,
+#             AuthenticationErrorRetries=0,
+#             ContentPolicyViolationErrorRetries=0,
+#         ),
+#     )
+#     mock_text_completion.assert_called_once_with(
+#         model=expected_model,
+#         prompt=prompt + "\n\nBEGIN RESPONSE:",
+#         temperature=temperature,
+#         max_tokens=max_tokens,
+#         cache=mock.ANY,
+#     )