diff --git a/lagent/actions/web_browser.py b/lagent/actions/web_browser.py index 235e149..29f7594 100644 --- a/lagent/actions/web_browser.py +++ b/lagent/actions/web_browser.py @@ -18,7 +18,6 @@ from asyncache import cached as acached from bs4 import BeautifulSoup from cachetools import TTLCache, cached -from duckduckgo_search import DDGS, AsyncDDGS from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api from lagent.actions.parser import BaseParser, JsonParser @@ -78,6 +77,8 @@ def search(self, query: str, max_retry: int = 3) -> dict: @acached(cache=TTLCache(maxsize=100, ttl=600)) async def asearch(self, query: str, max_retry: int = 3) -> dict: + from duckduckgo_search import AsyncDDGS + for attempt in range(max_retry): try: ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy) @@ -92,6 +93,8 @@ async def asearch(self, query: str, max_retry: int = 3) -> dict: raise Exception('Failed to get search results from DuckDuckGo after retries.') async def _async_call_ddgs(self, query: str, **kwargs) -> dict: + from duckduckgo_search import DDGS + ddgs = DDGS(**kwargs) try: response = await asyncio.wait_for( @@ -215,7 +218,8 @@ class BraveSearch(BaseSearch): topk (int): The number of search results returned in response from API search results. region (str): The country code string. Specifies the country where the search results come from. language (str): The language code string. Specifies the preferred language for the search results. - extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results. + extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the + search results. **kwargs: Any other parameters related to the Brave Search API. Find more details at https://api.search.brave.com/app/documentation/web-search/get-started. """ diff --git a/lagent/llms/anthropic_llm.py b/lagent/llms/anthropic_llm.py index 8cd7802..fe50bfa 100644 --- a/lagent/llms/anthropic_llm.py +++ b/lagent/llms/anthropic_llm.py @@ -9,7 +9,7 @@ from anthropic import NOT_GIVEN from requests.exceptions import ProxyError -from lagent.llms import AsyncBaseAPILLM, BaseAPILLM +from .base_api import AsyncBaseAPILLM, BaseAPILLM class ClaudeAPI(BaseAPILLM): @@ -222,7 +222,12 @@ def __init__( retry: int = 5, key: Union[str, List[str]] = 'ENV', proxies: Optional[Dict] = None, - meta_template: Optional[Dict] = None, + meta_template: Optional[Dict] = [ + dict(role='system', api_role='system'), + dict(role='user', api_role='user'), + dict(role='assistant', api_role='assistant'), + dict(role='environment', api_role='user'), + ], temperature: float = NOT_GIVEN, max_new_tokens: int = 512, top_p: float = NOT_GIVEN, @@ -278,8 +283,7 @@ async def chat( assert isinstance(inputs, list) gen_params = {**self.gen_params, **gen_params} tasks = [ - self._chat(self.template_parser(messages), **gen_params) - for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) + self._chat(messages, **gen_params) for messages in ([inputs] if isinstance(inputs[0], dict) else inputs) ] ret = await asyncio.gather(*tasks) return ret[0] if isinstance(inputs[0], dict) else ret @@ -333,7 +337,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str: str: The generated string. """ assert isinstance(messages, list) - + messages = self.template_parser(messages) data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params) max_num_retries = 0 diff --git a/requirements/optional.txt b/requirements/optional.txt index 75645db..0ae76df 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,3 +1,4 @@ +duckduckgo_search==5.3.1b1 google-search-results lmdeploy>=0.2.5 pillow diff --git a/requirements/runtime.txt b/requirements/runtime.txt index bb28b27..ac0b85c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -4,7 +4,6 @@ arxiv asyncache asyncer distro -duckduckgo_search==5.3.1b1 filelock func_timeout griffe<1.0