Commit
Mind search (#208)
Support MindSearch
Harold-lkk authored Jul 29, 2024
1 parent 139a810 commit 7167bbd
Showing 7 changed files with 307 additions and 9 deletions.
3 changes: 2 additions & 1 deletion lagent/actions/__init__.py
@@ -3,6 +3,7 @@
from .action_executor import ActionExecutor
from .arxiv_search import ArxivSearch
from .base_action import TOOL_REGISTRY, BaseAction, tool_api
from .bing_browser import BingBrowser
from .bing_map import BINGMap
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .google_scholar_search import GoogleScholar
@@ -20,7 +21,7 @@
'GoogleScholar', 'IPythonInterpreter', 'IPythonInteractive',
'IPythonInteractiveManager', 'PythonInterpreter', 'PPT', 'BaseParser',
'JsonParser', 'TupleParser', 'tool_api', 'list_tools', 'get_tool_cls',
'get_tool'
'get_tool', 'BingBrowser'
]


270 changes: 270 additions & 0 deletions lagent/actions/bing_browser.py
@@ -0,0 +1,270 @@
import json
import logging
import random
import re
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Tuple, Type, Union

import requests
from bs4 import BeautifulSoup
from cachetools import TTLCache, cached
from duckduckgo_search import DDGS

from lagent.actions import BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser


class BaseSearch:

def __init__(self, topk: int = 3, black_list: List[str] = None):
self.topk = topk
self.black_list = black_list

def _filter_results(self, results: List[tuple]) -> dict:
filtered_results = {}
count = 0
for url, snippet, title in results:
if all(domain not in url
for domain in self.black_list) and not url.endswith('.pdf'):
filtered_results[count] = {
'url': url,
'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
'title': title
}
count += 1
if count >= self.topk:
break
return filtered_results


class DuckDuckGoSearch(BaseSearch):

def __init__(self,
topk: int = 3,
black_list: List[str] = [
'enoN',
'youtube.com',
'bilibili.com',
'researchgate.net',
],
**kwargs):
self.proxy = kwargs.get('proxy')
self.timeout = kwargs.get('timeout', 10)
super().__init__(topk, black_list)

@cached(cache=TTLCache(maxsize=100, ttl=600))
def search(self, query: str, max_retry: int = 3) -> dict:
for attempt in range(max_retry):
try:
response = self._call_ddgs(
query, timeout=self.timeout, proxy=self.proxy)
return self._parse_response(response)
except Exception as e:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from DuckDuckGo after retries.')

def _call_ddgs(self, query: str, **kwargs) -> dict:
ddgs = DDGS(**kwargs)
response = ddgs.text(query.strip("'"), max_results=10)
return response

def _parse_response(self, response: dict) -> dict:
raw_results = []
for item in response:
raw_results.append(
(item['href'], item['description']
if 'description' in item else item['body'], item['title']))
return self._filter_results(raw_results)


class BingSearch(BaseSearch):

def __init__(self,
api_key: str,
region: str = 'zh-CN',
topk: int = 3,
black_list: List[str] = [
'enoN',
'youtube.com',
'bilibili.com',
'researchgate.net',
],
**kwargs):
self.api_key = api_key
self.market = region
self.proxy = kwargs.get('proxy')
super().__init__(topk, black_list)

@cached(cache=TTLCache(maxsize=100, ttl=600))
def search(self, query: str, max_retry: int = 3) -> dict:
for attempt in range(max_retry):
try:
response = self._call_bing_api(query)
return self._parse_response(response)
except Exception as e:
logging.exception(str(e))
warnings.warn(
f'Retry {attempt + 1}/{max_retry} due to error: {e}')
time.sleep(random.randint(2, 5))
raise Exception(
'Failed to get search results from Bing Search after retries.')

def _call_bing_api(self, query: str) -> dict:
endpoint = 'https://api.bing.microsoft.com/v7.0/search'
params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
headers = {'Ocp-Apim-Subscription-Key': self.api_key}
response = requests.get(
endpoint, headers=headers, params=params, proxies=self.proxy)
response.raise_for_status()
return response.json()

def _parse_response(self, response: dict) -> dict:
webpages = {
w['id']: w
for w in response.get('webPages', {}).get('value', [])
}
raw_results = []

for item in response.get('rankingResponse',
{}).get('mainline', {}).get('items', []):
if item['answerType'] == 'WebPages':
webpage = webpages.get(item['value']['id'])
if webpage:
raw_results.append(
(webpage['url'], webpage['snippet'], webpage['name']))
elif item['answerType'] == 'News' and item['value'][
'id'] == response.get('news', {}).get('id'):
for news in response.get('news', {}).get('value', []):
raw_results.append(
(news['url'], news['description'], news['name']))

return self._filter_results(raw_results)


class ContentFetcher:

def __init__(self, timeout: int = 5):
self.timeout = timeout

@cached(cache=TTLCache(maxsize=100, ttl=600))
def fetch(self, url: str) -> Tuple[bool, str]:
try:
response = requests.get(url, timeout=self.timeout)
response.raise_for_status()
html = response.content
except requests.RequestException as e:
return False, str(e)

text = BeautifulSoup(html, 'html.parser').get_text()
cleaned_text = re.sub(r'\n+', '\n', text)
return True, cleaned_text


class BingBrowser(BaseAction):
"""Wrapper around the Web Browser Tool.
"""

def __init__(self,
searcher_type: str = 'DuckDuckGoSearch',
timeout: int = 5,
black_list: Optional[List[str]] = [
'enoN',
'youtube.com',
'bilibili.com',
'researchgate.net',
],
topk: int = 20,
description: Optional[dict] = None,
parser: Type[BaseParser] = JsonParser,
enable: bool = True,
**kwargs):
self.searcher = eval(searcher_type)(
black_list=black_list, topk=topk, **kwargs)
self.fetcher = ContentFetcher(timeout=timeout)
self.search_results = None
super().__init__(description, parser, enable)

@tool_api
def search(self, query: Union[str, List[str]]) -> dict:
"""BING search API
Args:
query (List[str]): list of search query strings
"""
queries = query if isinstance(query, list) else [query]
search_results = {}

with ThreadPoolExecutor() as executor:
future_to_query = {
executor.submit(self.searcher.search, q): q
for q in queries
}

for future in as_completed(future_to_query):
query = future_to_query[future]
try:
results = future.result()
except Exception as exc:
warnings.warn(f'{query} generated an exception: {exc}')
else:
for result in results.values():
if result['url'] not in search_results:
search_results[result['url']] = result
else:
search_results[
result['url']]['summ'] += f"\n{result['summ']}"

self.search_results = {
idx: result
for idx, result in enumerate(search_results.values())
}
return self.search_results

@tool_api
def select(self, select_ids: List[int]) -> dict:
"""get the detailed content on the selected pages.
Args:
select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
"""
if not self.search_results:
raise ValueError('No search results to select from.')

new_search_results = {}
with ThreadPoolExecutor() as executor:
future_to_id = {
executor.submit(self.fetcher.fetch,
self.search_results[select_id]['url']):
select_id
for select_id in select_ids if select_id in self.search_results
}

for future in as_completed(future_to_id):
select_id = future_to_id[future]
try:
web_success, web_content = future.result()
except Exception as exc:
warnings.warn(f'{select_id} generated an exception: {exc}')
else:
if web_success:
self.search_results[select_id][
'content'] = web_content[:8192]
new_search_results[select_id] = self.search_results[
select_id].copy()
new_search_results[select_id].pop('summ')

return new_search_results

@tool_api
def open_url(self, url: str) -> dict:
print(f'Start Browsing: {url}')
web_success, web_content = self.fetcher.fetch(url)
if web_success:
return {'type': 'text', 'content': web_content}
else:
return {'error': web_content}
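
The new BingBrowser action exposes three tool APIs: search fans one or more queries out over a thread pool and merges hits that share a URL, select fetches the full text of chosen result indices (truncated to 8192 characters, replacing the 'summ' snippet with 'content'), and open_url fetches a single page. A minimal usage sketch follows; it is not part of the commit, assumes the @tool_api-decorated methods remain directly callable, and uses the key-free DuckDuckGoSearch backend:

from lagent.actions import BingBrowser

# Default backend is DuckDuckGoSearch; pass searcher_type='BingSearch'
# together with api_key=... to route through the Bing Web Search API instead.
browser = BingBrowser(searcher_type='DuckDuckGoSearch', topk=5)

# Fan out queries; results arrive as {index: {'url', 'summ', 'title'}},
# deduplicated by URL across all queries.
results = browser.search(query=['MindSearch agent', 'lagent framework'])
for idx, hit in results.items():
    print(idx, hit['title'], hit['url'])

# Fetch full page text for selected indices; returned entries gain
# 'content' and drop 'summ'.
pages = browser.select(select_ids=[0, 1])

# Or fetch one URL directly; returns {'type': 'text', 'content': ...}
# on success and {'error': ...} on failure.
page = browser.open_url('https://github.com/InternLM/lagent')
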
5 changes: 4 additions & 1 deletion lagent/actions/parser.py
@@ -72,7 +72,10 @@ def parse_outputs(self, outputs: Any) -> List[dict]:
outputs = json.dumps(outputs, ensure_ascii=False)
elif not isinstance(outputs, str):
outputs = str(outputs)
return [{'type': 'text', 'content': outputs}]
return [{
'type': 'text',
'content': outputs.encode('gbk', 'ignore').decode('gbk')
}]


class JsonParser(BaseParser):
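
The added return path round-trips the output text through GBK with errors='ignore', which silently drops any character that has no GBK encoding (most emoji, for example) while leaving ASCII and Chinese text intact. An illustrative snippet, not from the commit:

# Characters outside GBK (the rocket emoji here) are dropped; CJK survives.
s = '检索完成 🚀 done'
print(s.encode('gbk', 'ignore').decode('gbk'))  # -> '检索完成  done'
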
20 changes: 15 additions & 5 deletions lagent/agents/internlm2_agent.py
@@ -4,6 +4,8 @@
from copy import deepcopy
from typing import Dict, List, Optional, Union

from termcolor import colored

from lagent.actions import ActionExecutor
from lagent.agents.base_agent import BaseAgent
from lagent.llms import BaseAPIModel, BaseModel
@@ -160,8 +162,10 @@ def format(self,
formatted += self.format_sub_role(inner_step)
return formatted

def parse(self, message, plugin_executor: ActionExecutor,
interpreter_executor: ActionExecutor):
def parse(self,
message,
plugin_executor: ActionExecutor = None,
interpreter_executor: ActionExecutor = None):
if self.language['begin']:
message = message.split(self.language['begin'])[-1]
if self.tool['name_map']['plugin'] in message:
@@ -183,9 +187,10 @@ def parse(self, message, plugin_executor: ActionExecutor,
message = message.strip()
code = code.split(self.tool['end'].strip())[0].strip()
return 'interpreter', message, dict(
name=interpreter_executor.action_names()[0],
parameters=dict(
command=code)) if interpreter_executor else None
name=interpreter_executor.action_names()[0] if isinstance(
interpreter_executor, ActionExecutor) else
'IPythonInterpreter',
parameters=dict(command=code))
return None, message.split(self.tool['start_token'])[0], None

def format_response(self, action_return, name) -> dict:
@@ -285,6 +290,7 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
inner_history = message[:]
offset = len(inner_history)
agent_return = AgentReturn()
agent_return.inner_steps = deepcopy(inner_history)
last_agent_state = AgentStatusCode.SESSION_READY
for _ in range(self.max_turn):
# list of dict
@@ -348,11 +354,14 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
agent_return.response = language
last_agent_state = agent_state
yield deepcopy(agent_return)
print(colored(response, 'red'))
if name:
action_return: ActionReturn = executor(action['name'],
action['parameters'])
action_return.type = action['name']
action_return.thought = language
agent_return.actions.append(action_return)
print(colored(action_return.result, 'magenta'))
inner_history.append(dict(role='language', content=language))
if not name:
agent_return.response = language
@@ -372,6 +381,7 @@ def stream_chat(self, message: List[dict], **kwargs) -> AgentReturn:
self._protocol.format_response(action_return, name=name))
agent_state += 1
agent_return.state = agent_state
agent_return.inner_steps = deepcopy(inner_history[offset:])
yield agent_return
agent_return.inner_steps = deepcopy(inner_history[offset:])
agent_return.state = AgentStatusCode.END
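
Because both executors now default to None, the protocol can parse a raw completion before any tools are attached, with interpreter calls falling back to the action name 'IPythonInterpreter'. A hedged sketch (Internlm2Protocol comes from this module; the special-token format and the sample completion are assumptions based on the protocol defaults):

from lagent.agents.internlm2_agent import Internlm2Protocol

protocol = Internlm2Protocol()
completion = (
    'Let me run some code.<|action_start|><|interpreter|>\n'
    '```python\nprint(1 + 1)\n```<|action_end|>'
)
# No executor attached: the action name falls back to 'IPythonInterpreter'.
kind, thought, action = protocol.parse(completion)
print(kind, action['name'])  # -> interpreter IPythonInterpreter
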
3 changes: 2 additions & 1 deletion lagent/llms/base_api.py
@@ -174,7 +174,8 @@ def __init__(self,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
stop_words=stop_words)
stop_words=stop_words,
skip_special_tokens=False)

def _wait(self):
"""Wait till the next query can be sent.
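
The added skip_special_tokens=False matters because the Internlm2 protocol locates tool calls by scanning decoded completions for marker tokens such as '<|action_start|>'; stripping them at decode time would make every tool call invisible to the parser. A hedged sketch of the resulting generation parameters (the surrounding values are illustrative, not the library's actual defaults):

gen_params = dict(
    max_new_tokens=512,
    top_p=0.8,
    top_k=None,
    temperature=0.8,
    repetition_penalty=1.0,
    stop_words=['<|im_end|>'],
    skip_special_tokens=False,  # keep '<|action_start|>' etc. in outputs
)
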