-
Notifications
You must be signed in to change notification settings - Fork 0
/
web_explorer.py
102 lines (80 loc) · 3.87 KB
/
web_explorer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
import os
# Credentials are copied from Streamlit secrets into environment variables,
# presumably so the clients that settings() constructs without explicit keys
# (GoogleSearchAPIWrapper, OpenAIEmbeddings, ChatOpenAI) can pick them up.
#   GOOGLE_API_KEY - https://console.cloud.google.com/apis/api/customsearch.googleapis.com/credentials
#   GOOGLE_CSE_ID  - https://programmablesearchengine.google.com/
#   OPENAI_API_KEY - https://beta.openai.com/account/api-keys
for _secret_name in ("GOOGLE_API_KEY", "GOOGLE_CSE_ID"):
    os.environ[_secret_name] = st.secrets[_secret_name]
# Pin the OpenAI endpoint explicitly (overrides any value already in the env).
os.environ["OPENAI_API_BASE"] = "https://api.openai.com/v1"
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Streamlit requires set_page_config to run before any other rendering command.
st.set_page_config(page_title="Interweb Explorer", page_icon="🌐")
def settings(
    *,
    model_name="gpt-3.5-turbo-16k",
    embedding_size=1536,
    num_search_results=3,
):
    """Build the web-research retriever and the chat model used by the app.

    Each call constructs everything from scratch: an empty FAISS vector store
    and fresh OpenAI and Google Search clients. Callers should therefore cache
    the result; the module-level code keeps it in ``st.session_state``.

    Keyword Args:
        model_name: OpenAI chat model passed to ``ChatOpenAI``. The same LLM
            instance is handed to ``WebResearchRetriever.from_llm`` and is
            also returned to the caller.
        embedding_size: Dimensionality of the ``faiss.IndexFlatL2`` index.
            It must equal the length of the vectors ``OpenAIEmbeddings``
            produces. The default, 1536, is the value this app was written
            against.
        num_search_results: Forwarded unchanged to
            ``WebResearchRetriever.from_llm``.

    Returns:
        A ``(web_retriever, llm)`` tuple.
    """
    # Imported lazily so these heavy dependencies load only when the
    # retriever is actually built.
    import faiss
    from langchain.chat_models import ChatOpenAI
    from langchain.docstore import InMemoryDocstore
    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.utilities import GoogleSearchAPIWrapper
    from langchain.vectorstores import FAISS

    # Vectorstore: an empty, in-memory FAISS store handed to the retriever.
    embeddings_model = OpenAIEmbeddings()
    index = faiss.IndexFlatL2(embedding_size)
    # Positional args are: embedding function, index, docstore, and the
    # (initially empty) index-position -> docstore-id mapping.
    vectorstore_public = FAISS(
        embeddings_model.embed_query, index, InMemoryDocstore({}), {}
    )

    # streaming=True makes the model emit per-token callbacks, which
    # StreamHandler.on_llm_new_token renders while the answer is generated.
    llm = ChatOpenAI(model_name=model_name, temperature=0, streaming=True)

    # No explicit credentials: GOOGLE_API_KEY / GOOGLE_CSE_ID are expected in
    # the environment (they are set at module import time).
    search = GoogleSearchAPIWrapper()

    web_retriever = WebResearchRetriever.from_llm(
        vectorstore=vectorstore_public,
        llm=llm,
        search=search,
        num_search_results=num_search_results,
    )
    return web_retriever, llm
class StreamHandler(BaseCallbackHandler):
    """LangChain callback that live-renders streamed LLM output in Streamlit.

    Every new token is appended to ``self.text``, and the whole accumulated
    text is redrawn with ``container.info``. The element therefore always
    shows the complete partial answer.

    Args:
        container: Streamlit element whose ``info`` method is called after
            each token. The app passes an ``st.empty()`` placeholder.
        initial_text: Text the buffer starts with; it prefixes everything
            rendered.
    """

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the buffer and redraw the full buffer."""
        updated = self.text + token
        self.text = updated
        self.container.info(updated)
class PrintRetrievalHandler(BaseCallbackHandler):
    """LangChain callback that shows retriever activity in a Streamlit expander.

    Args:
        container: Streamlit container in which a "Context Retrieval"
            expander is created. All output from this handler goes into that
            expander.
    """

    def __init__(self, container):
        self.container = container.expander("Context Retrieval")

    def on_retriever_start(self, *args, **kwargs):
        """Write the query the retriever was invoked with.

        LangChain versions differ in how they invoke this hook. Older ones
        pass ``(query, ...)`` and newer ones pass ``(serialized, query, ...)``.
        A fixed ``(self, query)`` signature raises ``TypeError`` on the newer
        form, and the callback manager swallows it, so the question is never
        shown. The query is therefore taken from the ``query`` keyword when
        present, otherwise from the last positional argument.
        """
        if "query" in kwargs:
            query = kwargs["query"]
        elif args:
            query = args[-1]
        else:
            query = ""
        self.container.write(f"**Question:** {query}")

    def on_retriever_end(self, documents, **kwargs):
        """Write each retrieved document's source heading followed by its text.

        Args:
            documents: Retrieved documents. Each needs ``.metadata`` (a
                mapping) and ``.page_content``.
        """
        for doc in documents:
            # .get(): one document lacking "source" metadata must not raise
            # and abort rendering of the remaining results.
            source = doc.metadata.get("source", "unknown source")
            self.container.write(f"**Results from {source}**")
            self.container.text(doc.page_content)
st.header("`Interweb Explorer`")
# NOTE(review): the intro advertises a "private (no data sharing)" mode, but
# settings() only builds the public OpenAI + Google Search configuration.
st.info("`I am an AI that can answer questions by exploring, reading, and summarizing web pages. "
        "I can be configured to use different modes: public API or private (no data sharing).`")

# Make retriever and llm once per browser session. Streamlit reruns this script
# on every interaction, and settings() builds fresh clients and an empty
# vector store each time it is called.
if 'retriever' not in st.session_state:
    st.session_state['retriever'], st.session_state['llm'] = settings()
web_retriever = st.session_state.retriever
llm = st.session_state.llm

# User input
question = st.text_input("`Ask a question:`")

if question:
    # Surface INFO-level log messages from LangChain's web_research retriever
    # module on the console.
    import logging

    logging.basicConfig()
    logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)

    # Generate answer (w/ citations): the chain's result carries both
    # 'answer' and 'sources'.
    qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=web_retriever)

    # Retrieved context is written into an expander. The answer streams into a
    # placeholder, which is then overwritten with the final answer text.
    retrieval_streamer_cb = PrintRetrievalHandler(st.container())
    answer = st.empty()
    stream_handler = StreamHandler(answer, initial_text="`Answer:`\n\n")
    result = qa_chain({"question": question}, callbacks=[retrieval_streamer_cb, stream_handler])
    answer.info('`Answer:`\n\n' + result['answer'])
    st.info('`Sources:`\n\n' + result['sources'])