Commit

Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0 (#1762)

* Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0

* Temporarily remove retry tests

* fix test
okhat authored Nov 6, 2024
1 parent 1864a27 commit e654062
Showing 3 changed files with 167 additions and 267 deletions.
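
In practice, the revert swaps the cached per-model LiteLLM Router (and its RetryPolicy) for a plain litellm.completion call that forwards num_retries directly, with the default retry count lowered from 8 to 3; per-error-type retry tuning, which the removed RetryPolicy provided, is no longer configured here. A minimal sketch of the post-revert call path, assuming litellm >= 1.51.0, an OPENAI_API_KEY in the environment, and an illustrative model name:

    import litellm

    # litellm itself retries transient failures (timeouts, rate limits)
    # when num_retries is passed; no Router or model_list is required.
    response = litellm.completion(
        model="openai/gpt-4o-mini",  # hypothetical model, for illustration
        messages=[{"role": "user", "content": "Say hello."}],
        num_retries=3,  # mirrors the new dspy.LM default below
    )
    print(response.choices[0].message.content)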
138 changes: 19 additions & 119 deletions dspy/clients/lm.py
@@ -1,21 +1,19 @@
 import functools
+from .base_lm import BaseLM
 import logging
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

 import litellm
 import ujson
-from litellm import Router
-from litellm.router import RetryPolicy

 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
 from dspy.utils.callback import BaseCallback, with_callbacks

 from .base_lm import BaseLM

logger = logging.getLogger(__name__)

@@ -34,7 +32,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 8,
+        num_retries: int = 3,
         **kwargs,
     ):
         """
@@ -165,55 +163,6 @@ def copy(self, **kwargs):
         return new_instance


-@dataclass(frozen=True)
-class _ProviderAPIConfig:
-    """
-    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
-    """
-
-    api_key: Optional[str]
-    api_base: Optional[str]
-    api_version: Optional[str]
-    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
-    # For all other providers, this field is empty
-    azure_ad_token: Optional[str]
-
-
-def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
-    """
-    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
-    provider corresponding to the given model.
-    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
-    the input dictionary.
-    """
-    provider = _get_provider(model)
-    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
-    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
-    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
-    if "azure" in provider:
-        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
-    else:
-        azure_ad_token = None
-    return _ProviderAPIConfig(
-        api_key=api_key,
-        api_base=api_base,
-        api_version=api_version,
-        azure_ad_token=azure_ad_token,
-    )
-
-
-def _get_provider(model: str) -> str:
-    """
-    Extract the provider name from the model string of the format "<provider_name>/<model_name>",
-    e.g. "openai/gpt-4".
-    TODO: Not all the models are in the format of "provider/model"
-    """
-    model = model.split("/", 1)
-    return model[0] if len(model) > 1 else "openai"


 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
@@ -225,69 +174,13 @@ def cached_litellm_completion(request, num_retries: int):

 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
-    return router.completion(
+    return litellm.completion(
+        num_retries=num_retries,
         cache=cache,
         **kwargs,
     )


-@functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
-    """
-    Get a LiteLLM router for the given model with the specified number of retries
-    for transient errors.
-    Args:
-        model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
-        num_retries: The number of times to retry a request if it fails transiently due to
-                     network error, rate limiting, etc. Requests are retried with exponential
-                     backoff.
-        api_config: The API configurations (keys, base URL, etc.) for the provider
-                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
-    Returns:
-        A LiteLLM router instance that can be used to query the given model.
-    """
-    retry_policy = RetryPolicy(
-        TimeoutErrorRetries=num_retries,
-        RateLimitErrorRetries=num_retries,
-        InternalServerErrorRetries=num_retries,
-        # We don't retry on errors that are unlikely to be transient
-        # (e.g. bad request, invalid auth credentials)
-        BadRequestErrorRetries=0,
-        AuthenticationErrorRetries=0,
-        ContentPolicyViolationErrorRetries=0,
-    )
-
-    # LiteLLM routers must specify a `model_list`, which maps model names passed
-    # to `completions()` into actual LiteLLM model names. For our purposes, the
-    # model name is the same as the LiteLLM model name, so we add a single
-    # entry to the `model_list` that maps the model name to itself
-    litellm_params = {
-        "model": model,
-    }
-    if api_config.api_key is not None:
-        litellm_params["api_key"] = api_config.api_key
-    if api_config.api_base is not None:
-        litellm_params["api_base"] = api_config.api_base
-    if api_config.api_version is not None:
-        litellm_params["api_version"] = api_config.api_version
-    if api_config.azure_ad_token is not None:
-        litellm_params["azure_ad_token"] = api_config.azure_ad_token
-    model_list = [
-        {
-            "model_name": model,
-            "litellm_params": litellm_params,
-        }
-    ]
-    return Router(
-        model_list=model_list,
-        retry_policy=retry_policy,
-    )
 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request, num_retries: int):
     return litellm_text_completion(
@@ -299,18 +192,25 @@ def cached_litellm_text_completion(request, num_retries: int):

 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    model = kwargs.pop("model")
-    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
-    model_name = model.split("/", 1)[-1]
-    text_completion_model_name = f"text-completion-openai/{model_name}"

+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = kwargs.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the kwargs, or from the environment.
+    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
-    return router.text_completion(
+    return litellm.text_completion(
         cache=cache,
-        model=text_completion_model_name,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
         **kwargs,
     )
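
With the router gone, retries are configured solely through the num_retries argument that LM.__init__ exposes (default now 3, down from 8) and forwards into the litellm.completion / litellm.text_completion calls above. A usage sketch, assuming this branch of dspy and a configured OpenAI key; the model name is illustrative:

    import dspy

    # num_retries flows from the LM constructor into every litellm call.
    lm = dspy.LM("openai/gpt-4o-mini", num_retries=5)  # hypothetical model
    dspy.configure(lm=lm)
    print(lm("Say hello."))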
12 changes: 6 additions & 6 deletions poetry.lock
