Skip to content

Commit

Permalink
[Fix] Alternative for aioify (#274)
Browse files Browse the repository at this point in the history
replace `aioify` with `asyncify`
  • Loading branch information
braisedpork1964 authored Nov 18, 2024
1 parent e72713a commit b2bf23d
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 123 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ repos:
hooks:
- id: black
args: ["--line-length", "119", "--skip-string-normalization"]


- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand Down
12 changes: 5 additions & 7 deletions lagent/actions/arxiv_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Type

from aioify import aioify
from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
Expand Down Expand Up @@ -42,12 +42,10 @@ def get_arxiv_article_information(self, query: str) -> dict:

try:
results = arxiv.Search( # type: ignore
query[:self.max_query_len],
max_results=self.top_k_results).results()
query[: self.max_query_len], max_results=self.top_k_results
).results()
except Exception as exc:
return ActionReturn(
errmsg=f'Arxiv exception: {exc}',
state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
docs = [
f'Published: {result.updated.date()}\nTitle: {result.title}\n'
f'Authors: {", ".join(a.name for a in result.authors)}\n'
Expand All @@ -67,7 +65,7 @@ class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
"""

@tool_api(explode_return=True)
@aioify
@asyncify
def get_arxiv_article_information(self, query: str) -> dict:
"""Run Arxiv search and get the article meta information.
Expand Down
185 changes: 102 additions & 83 deletions lagent/actions/google_scholar_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Optional, Type

from aioify import aioify
from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
Expand Down Expand Up @@ -31,7 +31,8 @@ def __init__(
if api_key is None:
raise ValueError(
'Please set Serper API key either in the environment '
'as SERPER_API_KEY or pass it as `api_key` parameter.')
'as SERPER_API_KEY or pass it as `api_key` parameter.'
)
self.api_key = api_key

@tool_api(explode_return=True)
Expand Down Expand Up @@ -78,6 +79,7 @@ def search_google_scholar(
- pub_info: publication information of selected papers
"""
from serpapi import GoogleSearch

params = {
'q': query,
'engine': 'google_scholar',
Expand All @@ -94,7 +96,7 @@ def search_google_scholar(
'as_sdt': as_sdt,
'safe': safe,
'filter': filter,
'as_vis': as_vis
'as_vis': as_vis,
}
search = GoogleSearch(params)
try:
Expand All @@ -112,27 +114,24 @@ def search_google_scholar(
cited_by.append(citation['total'])
snippets.append(item['snippet'])
organic_id.append(item['result_id'])
return dict(
title=title,
cited_by=cited_by,
organic_id=organic_id,
snippets=snippets)
return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets)
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
def get_author_information(
self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None,
) -> dict:
"""Search for an author's information by author's id provided by get_author_id.
Args:
Expand All @@ -155,6 +154,7 @@ def get_author_information(self,
* website: the author's homepage url
"""
from serpapi import GoogleSearch

params = {
'engine': 'google_scholar_author',
'author_id': author_id,
Expand All @@ -167,7 +167,7 @@ def get_author_information(self,
'num': num,
'no_cache': no_cache,
'async': async_req,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -178,20 +178,19 @@ def get_author_information(self,
name=author['name'],
affiliations=author.get('affiliations', ''),
website=author.get('website', ''),
articles=[
dict(title=article['title'], authors=article['authors'])
for article in articles[:3]
])
articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
)
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
def get_citation_format(
self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json',
) -> dict:
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
Args:
Expand All @@ -206,13 +205,14 @@ def get_citation_format(self,
* citation: the citation format of the article
"""
from serpapi import GoogleSearch

params = {
'q': q,
'engine': 'google_scholar_cite',
'api_key': self.api_key,
'no_cache': no_cache,
'async': async_,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -221,18 +221,19 @@ def get_citation_format(self,
citation_info = citation[0]['snippet']
return citation_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
def get_author_id(
self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json',
) -> dict:
"""The getAuthorId function is used to get the author's id by his or her name.
Args:
Expand All @@ -249,6 +250,7 @@ def get_author_id(self,
* author_id: the author_id of the author
"""
from serpapi import GoogleSearch

params = {
'mauthors': mauthors,
'engine': 'google_scholar_profiles',
Expand All @@ -258,7 +260,7 @@ def get_author_id(self,
'before_author': before_author,
'no_cache': no_cache,
'async': _async,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -267,8 +269,7 @@ def get_author_id(self,
author_info = dict(author_id=profile[0]['author_id'])
return author_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)


class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
Expand All @@ -283,7 +284,7 @@ class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
"""

@tool_api(explode_return=True)
@aioify
@asyncify
def search_google_scholar(
self,
query: str,
Expand Down Expand Up @@ -326,23 +327,38 @@ def search_google_scholar(
- organic_id: a list of the organic results' ids of the three selected papers
- pub_info: publication information of selected papers
"""
return super().search_google_scholar(query, cites, as_ylo, as_yhi,
scisbd, cluster, hl, lr, start,
num, as_sdt, safe, filter, as_vis)
return super().search_google_scholar(
query,
cites,
as_ylo,
as_yhi,
scisbd,
cluster,
hl,
lr,
start,
num,
as_sdt,
safe,
filter,
as_vis,
)

@tool_api(explode_return=True)
@aioify
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
@asyncify
def get_author_information(
self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None,
) -> dict:
"""Search for an author's information by author's id provided by get_author_id.
Args:
Expand All @@ -364,17 +380,19 @@ def get_author_information(self,
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
return super().get_author_information(author_id, hl, view_op, sort,
citation_id, start, num,
no_cache, async_req, output)
return super().get_author_information(
author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
)

@tool_api(explode_return=True)
@aioify
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
@asyncify
def get_citation_format(
self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json',
) -> dict:
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
Args:
Expand All @@ -391,15 +409,17 @@ def get_citation_format(self,
return super().get_citation_format(q, no_cache, async_, output)

@tool_api(explode_return=True)
@aioify
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
@asyncify
def get_author_id(
self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json',
) -> dict:
"""The getAuthorId function is used to get the author's id by his or her name.
Args:
Expand All @@ -415,5 +435,4 @@ def get_author_id(self,
:class:`dict`: author id
* author_id: the author_id of the author
"""
return super().get_author_id(mauthors, hl, after_author, before_author,
no_cache, _async, output)
return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
Loading

0 comments on commit b2bf23d

Please sign in to comment.