Skip to content

Commit ea6cbba

Browse files
authored
Merge pull request #54 from Gentopia-AI/tool-fix
better than arxiv_search
2 parents d4e112b + ace29dd commit ea6cbba

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

gentopia/tools/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .basetool import BaseTool
22
from .google_search import GoogleSearch
3+
from .google_scholar import *
34
from .calculator import Calculator
45
from .wikipedia import Wikipedia
56
from .wolfram_alpha import WolframAlpha
@@ -33,6 +34,12 @@ def load_tools(name: str) -> BaseTool:
3334
"wikipedia": Wikipedia,
3435
"web_page": WebPage,
3536
"wolfram_alpha": WolframAlpha,
37+
"search_author_by_name": SearchAuthorByName,
38+
"search_author_by_interests": SearchAuthorByInterests,
39+
"author_uid2paper": AuthorUID2Paper,
40+
"search_paper": SearchPaper,
41+
"search_related_paper": SearchRelatedPaper,
42+
"search_cite_paper": SearchCitePaper,
3643
}
3744
if name not in name2tool:
3845
raise NotImplementedError

gentopia/tools/google_scholar.py

+266
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from typing import AnyStr, List
2+
from scholarly import scholarly, ProxyGenerator
3+
from gentopia.tools.basetool import *
4+
from scholarly import ProxyGenerator
5+
from itertools import islice
6+
7+
8+
class SearchAuthorByNameArgs(BaseModel):
9+
author: str = Field(..., description="author name with the institute name (optional), e.g., Tan Lee")
10+
top_k: int = Field(..., description="number of results to display. 5 is prefered")
11+
12+
13+
class SearchAuthorByName(BaseTool):
14+
name = "search_author_by_name"
15+
description = ("search an author with google scholar."
16+
"input a name, return a list of authors with info (including uid)."
17+
"you can repeat calling the function to get next results."
18+
)
19+
args_schema: Optional[Type[BaseModel]] = SearchAuthorByNameArgs
20+
author: str = ""
21+
results: List = []
22+
23+
def _run(self, author: AnyStr, top_k: int = 5) -> str:
24+
if author != self.author:
25+
self.results = scholarly.search_author(author)
26+
self.author = author
27+
assert self.results is not None
28+
ans = []
29+
for it in islice(self.results, top_k):
30+
ans.append(str({
31+
'name': it["name"],
32+
'uid': it["scholar_id"],
33+
'affiliation': it["affiliation"],
34+
'interests': it['interests'],
35+
'citation': it['citedby'],
36+
}))
37+
if not ans:
38+
return "no furthur information available"
39+
return '\n\n'.join(ans)
40+
41+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
42+
raise NotImplementedError
43+
44+
45+
class SearchAuthorByInterestsArgs(BaseModel):
46+
interests: str = Field(..., description="research interests separated by comma, e.g., 'crowdsourcing,privacy'")
47+
top_k: int = Field(..., description="number of results to display. 5 is prefered.")
48+
49+
50+
class SearchAuthorByInterests(BaseTool):
51+
name = "search_author_by_interests"
52+
description = ("search authors given keywords of research interests"
53+
"input interests, return a list of authors."
54+
"you can repeat calling the function to get next results."
55+
)
56+
args_schema: Optional[Type[BaseModel]] = SearchAuthorByInterestsArgs
57+
interests: str = ""
58+
results: List = []
59+
60+
def _run(self, interests: AnyStr, top_k: int = 5) -> str:
61+
if interests != self.interests:
62+
self.results = scholarly.search_keywords(interests.split(','))
63+
self.interests = interests
64+
assert self.results is not None
65+
ans = []
66+
for it in islice(self.results, top_k):
67+
ans.append(str({
68+
'name': it["name"],
69+
'uid': it['scholar_id'],
70+
'affiliation': it['affiliation'],
71+
'interests': it['interests'],
72+
'citation': it['citedby'],
73+
}))
74+
if not ans:
75+
return "no furthur information available"
76+
return '\n\n'.join(ans)
77+
78+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
79+
raise NotImplementedError
80+
81+
82+
class AuthorUID2PaperArgs(BaseModel):
83+
uid: str = Field(..., description="a unique identifier assigned to author in Google scholar")
84+
sort_by: str = Field(..., description="either 'citedby' or 'year'.")
85+
top_k: int = Field(..., description="number of results to display. 5 is prefered")
86+
87+
88+
class AuthorUID2Paper(BaseTool):
89+
name = "author_uid2paper"
90+
description = ("search the papers given the UID of an author."
91+
"you can use search_author first to get UID."
92+
"you can repeat calling the function to get next results."
93+
)
94+
args_schema: Optional[Type[BaseModel]] = AuthorUID2PaperArgs
95+
uid: str = ""
96+
sort_by: str = ""
97+
results: List = []
98+
99+
def _run(self, uid: AnyStr, sort_by: AnyStr, top_k: int = 5) -> str:
100+
if uid != self.uid or sort_by != self.sort_by:
101+
author = scholarly.search_author_id(uid)
102+
author = scholarly.fill(author, sortby=sort_by)
103+
self.results = iter(author['publications'])
104+
self.uid = uid
105+
self.sort_by = sort_by
106+
assert self.results is not None
107+
ans = []
108+
for it in islice(self.results, top_k):
109+
ans.append(str({
110+
'title': it['bib']["title"],
111+
'pub_year': it['bib']['pub_year'],
112+
'venue': it['bib']['citation'],
113+
"abstract": it['bib']['abstract'],
114+
'url': it['pub_url'],
115+
'citation': it['num_citations'],
116+
}))
117+
if not ans:
118+
return "no furthur information available"
119+
return '\n\n'.join(ans)
120+
121+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
122+
raise NotImplementedError
123+
124+
125+
class SearchPaperArgs(BaseModel):
126+
title: str = Field(..., description="title name")
127+
sort_by: str = Field(..., description="either 'relevance' or 'date'.")
128+
top_k: int = Field(..., description="number of results to display. 5 is prefered. set to 1 if given the complete title")
129+
130+
131+
class SearchPaper(BaseTool):
132+
name = "search_paper"
133+
description = ("search a paper with the title relevant to the input text."
134+
"input text query, return a list of papers."
135+
"you can repeat calling the function to get next results."
136+
)
137+
args_schema: Optional[Type[BaseModel]] = SearchPaperArgs
138+
title: str = ""
139+
sort_by: str = ""
140+
results: List = []
141+
142+
def _run(self, title: AnyStr, sort_by: AnyStr, top_k: int = 5) -> str:
143+
if title != self.title or sort_by != self.sort_by:
144+
self.results = scholarly.search_pubs(title, sort_by=sort_by)
145+
self.title = title
146+
self.sort_by = sort_by
147+
assert self.results is not None
148+
ans = []
149+
for it in islice(self.results, top_k):
150+
ans.append(str({
151+
'title': it['bib']["title"],
152+
'author': it['bib']['author'],
153+
'pub_year': it['bib']['pub_year'],
154+
'venue': it['bib']['venue'],
155+
"abstract": it['bib']['abstract'],
156+
'url': it['pub_url'],
157+
'citation': it['num_citations'],
158+
}))
159+
if not ans:
160+
return "no furthur information available"
161+
return '\n\n'.join(ans)
162+
163+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
164+
raise NotImplementedError
165+
166+
167+
class SearchRelatedPaperArgs(BaseModel):
168+
title: str = Field(..., description="title name")
169+
top_k: int = Field(..., description="number of results to display. 5 is prefered.")
170+
171+
172+
class SearchRelatedPaper(BaseTool):
173+
name = "search_related_paper"
174+
description = ("search the papers related to the target one."
175+
"input the complete paper title, return a list of relevant papers."
176+
"you can repeat calling the function to get next results."
177+
)
178+
args_schema: Optional[Type[BaseModel]] = SearchRelatedPaperArgs
179+
title: str = ""
180+
results: List = []
181+
182+
def _run(self, title: AnyStr, top_k: int = 5) -> str:
183+
if title != self.title:
184+
# please make sure the title is complete
185+
paper = scholarly.search_single_pub(title)
186+
self.results = scholarly.get_related_articles(paper)
187+
self.title = title
188+
assert self.results is not None
189+
ans = []
190+
for it in islice(self.results, top_k):
191+
ans.append(str({
192+
'title': it['bib']["title"],
193+
'author': it['bib']['author'],
194+
'pub_year': it['bib']['pub_year'],
195+
'venue': it['bib']['venue'],
196+
"abstract": it['bib']['abstract'],
197+
'url': it['pub_url'],
198+
'citation': it['num_citations'],
199+
}))
200+
if not ans:
201+
return "no furthur information available"
202+
return '\n\n'.join(ans)
203+
204+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
205+
raise NotImplementedError
206+
207+
208+
class SearchCitePaperArgs(BaseModel):
209+
title: str = Field(..., description="title name")
210+
top_k: int = Field(..., description="number of results to display. 5 is prefered.")
211+
212+
213+
class SearchCitePaper(BaseTool):
214+
name = "search_cite_paper"
215+
description = ("search the papers citing to the target one."
216+
"input the complete paper title, return a list of papers citing the one."
217+
"you can repeat calling the function to get next results."
218+
)
219+
args_schema: Optional[Type[BaseModel]] = SearchCitePaperArgs
220+
title: str = ""
221+
results: List = []
222+
223+
def _run(self, title: AnyStr, top_k: int = 5) -> str:
224+
if title != self.title:
225+
# please make sure the title is complete
226+
paper = scholarly.search_single_pub(title)
227+
self.results = scholarly.citedby(paper)
228+
self.title = title
229+
assert self.results is not None
230+
ans = []
231+
for it in islice(self.results, top_k):
232+
ans.append(str({
233+
'title': it['bib']["title"],
234+
'author': it['bib']['author'],
235+
'pub_year': it['bib']['pub_year'],
236+
'venue': it['bib']['venue'],
237+
"abstract": it['bib']['abstract'],
238+
'url': it['pub_url'],
239+
'citation': it['num_citations'],
240+
}))
241+
if not ans:
242+
return "no furthur information available"
243+
return '\n\n'.join(ans)
244+
245+
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
246+
raise NotImplementedError
247+
248+
249+
if __name__ == "__main__":
250+
import pdb
251+
searcher1 = SearchAuthorByName()
252+
ans1 = searcher1._run("Tan Lee")
253+
# ans2 = searcher1._run("Tan Lee")
254+
# searcher3 = AuthorUID2Paper()
255+
# searcher3._run("5VTS11IAAAAJ")
256+
# searcher4 = SearchPaper()
257+
# ans4 = searcher4._run("Attention is all you need")
258+
# searcher5 = SearchAuthorByInterests()
259+
# searcher5._run("privacy,robustness")
260+
# print(ans4)
261+
# pdb.set_trace()
262+
search6 = SearchRelatedPaper()
263+
ans = search6._run("Attention is all you need")
264+
# search7 = SearchCitePaper()
265+
# ans = search7._run("Attention is all you need")
266+
print(ans)

0 commit comments

Comments
 (0)