From 064f6d276d624fab03dd17d1430abafa4e52c969 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Wed, 30 Oct 2024 16:07:17 -0700 Subject: [PATCH] Remove retry_strategy in LM and handle no-docstring functions in ReAct (#1725) * Remove retry_strategy in LM and handle no-docstring functions in ReAct * exponential_backoff_retry tests --- dspy/clients/lm.py | 4 +--- dspy/predict/react.py | 8 ++++---- tests/clients/test_lm.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 7ba73c600..02fe17402 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -39,7 +39,7 @@ def __init__( cache: bool = True, launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, - num_retries: int = 8, + num_retries: int = 3, **kwargs, ): """ @@ -186,7 +186,6 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s kwargs = ujson.loads(request) return litellm.completion( num_retries=num_retries, - retry_strategy="exponential_backoff_retry", cache=cache, **kwargs, ) @@ -223,7 +222,6 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, api_base=api_base, prompt=prompt, num_retries=num_retries, - retry_strategy="exponential_backoff_retry", **kwargs, ) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index a26f538f0..e2eb1e0d8 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -12,7 +12,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic annotations_func = func if inspect.isfunction(func) else func.__call__ self.func = func self.name = name or getattr(func, '__name__', type(func).__name__) - self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description") + self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "") self.args = { k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel) else get_annotation_name(v) @@ -50,10 +50,10 @@ def __init__(self, signature, tools: list[Callable], max_iters=5): tools["finish"] = Tool(func=lambda **kwargs: "Completed.", name="finish", desc=finish_desc, args=finish_args) for idx, tool in enumerate(tools.values()): - desc = tool.desc.replace("\n", " ") args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str}) - desc = f"whose description is {desc}. It takes arguments {args} in JSON format." - instr.append(f"({idx+1}) {tool.name}, {desc}") + desc = (f", whose description is {tool.desc}." if tool.desc else ".").replace('\n', " ") + desc += f" It takes arguments {args} in JSON format." + instr.append(f"({idx+1}) {tool.name}{desc}") signature_ = ( dspy.Signature({**signature.input_fields}, "\n".join(instr)) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 8825dab95..ef5d85a9c 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -11,7 +11,7 @@ def test_lm_chat_respects_max_retries(): assert litellm_completion_api.call_count == 1 assert litellm_completion_api.call_args[1]["max_retries"] == 17 - assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry" + # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry" def test_lm_completions_respects_max_retries(): @@ -22,4 +22,4 @@ def test_lm_completions_respects_max_retries(): assert litellm_completion_api.call_count == 1 assert litellm_completion_api.call_args[1]["max_retries"] == 17 - assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry" + # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"