Skip to content

Commit

Permalink
fix anthropic (#284)
Browse files Browse the repository at this point in the history
  • Loading branch information
Harold-lkk authored Dec 20, 2024
1 parent c4523f0 commit 784b12a
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
8 changes: 6 additions & 2 deletions lagent/actions/web_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from asyncache import cached as acached
from bs4 import BeautifulSoup
from cachetools import TTLCache, cached
from duckduckgo_search import DDGS, AsyncDDGS

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
Expand Down Expand Up @@ -78,6 +77,8 @@ def search(self, query: str, max_retry: int = 3) -> dict:

@acached(cache=TTLCache(maxsize=100, ttl=600))
async def asearch(self, query: str, max_retry: int = 3) -> dict:
from duckduckgo_search import AsyncDDGS

for attempt in range(max_retry):
try:
ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
Expand All @@ -92,6 +93,8 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict:
raise Exception('Failed to get search results from DuckDuckGo after retries.')

async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
from duckduckgo_search import DDGS

ddgs = DDGS(**kwargs)
try:
response = await asyncio.wait_for(
Expand Down Expand Up @@ -215,7 +218,8 @@ class BraveSearch(BaseSearch):
topk (int): The number of search results returned in response from API search results.
region (str): The country code string. Specifies the country where the search results come from.
language (str): The language code string. Specifies the preferred language for the search results.
extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the
search results.
**kwargs: Any other parameters related to the Brave Search API. Find more details at
https://api.search.brave.com/app/documentation/web-search/get-started.
"""
Expand Down
14 changes: 9 additions & 5 deletions lagent/llms/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from anthropic import NOT_GIVEN
from requests.exceptions import ProxyError

from lagent.llms import AsyncBaseAPILLM, BaseAPILLM
from .base_api import AsyncBaseAPILLM, BaseAPILLM


class ClaudeAPI(BaseAPILLM):
Expand Down Expand Up @@ -222,7 +222,12 @@ def __init__(
retry: int = 5,
key: Union[str, List[str]] = 'ENV',
proxies: Optional[Dict] = None,
meta_template: Optional[Dict] = None,
meta_template: Optional[Dict] = [
dict(role='system', api_role='system'),
dict(role='user', api_role='user'),
dict(role='assistant', api_role='assistant'),
dict(role='environment', api_role='user'),
],
temperature: float = NOT_GIVEN,
max_new_tokens: int = 512,
top_p: float = NOT_GIVEN,
Expand Down Expand Up @@ -278,8 +283,7 @@ async def chat(
assert isinstance(inputs, list)
gen_params = {**self.gen_params, **gen_params}
tasks = [
self._chat(self.template_parser(messages), **gen_params)
for messages in ([inputs] if isinstance(inputs[0], dict) else inputs)
self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs)
]
ret = await asyncio.gather(*tasks)
return ret[0] if isinstance(inputs[0], dict) else ret
Expand Down Expand Up @@ -333,7 +337,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
str: The generated string.
"""
assert isinstance(messages, list)

messages = self.template_parser(messages)
data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
max_num_retries = 0

Expand Down
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
duckduckgo_search==5.3.1b1
google-search-results
lmdeploy>=0.2.5
pillow
Expand Down
1 change: 0 additions & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ arxiv
asyncache
asyncer
distro
duckduckgo_search==5.3.1b1
filelock
func_timeout
griffe<1.0
Expand Down

0 comments on commit 784b12a

Please sign in to comment.