# chatbot.py
import os
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_ollama import ChatOllama
from qdrant_client import QdrantClient
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import streamlit as st


class ChatbotManager:
    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en",
        device: str = "cpu",
        encode_kwargs: dict = {"normalize_embeddings": True},
        llm_model: str = "llama3.2:3b",
        llm_temperature: float = 0.7,
        qdrant_url: str = "http://localhost:6333",
        collection_name: str = "vector_db",
    ):
        """
        Initializes the ChatbotManager with embedding models, LLM, and vector store.

        Args:
            model_name (str): The HuggingFace model name for embeddings.
            device (str): The device to run the model on ('cpu' or 'cuda').
            encode_kwargs (dict): Additional keyword arguments for encoding.
            llm_model (str): The local LLM model name for ChatOllama.
            llm_temperature (float): Temperature setting for the LLM.
            qdrant_url (str): The URL for the Qdrant instance.
            collection_name (str): The name of the Qdrant collection.
        """
        self.model_name = model_name
        self.device = device
        self.encode_kwargs = encode_kwargs
        self.llm_model = llm_model
        self.llm_temperature = llm_temperature
        self.qdrant_url = qdrant_url
        self.collection_name = collection_name

        # Initialize Embeddings
        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name=self.model_name,
            model_kwargs={"device": self.device},
            encode_kwargs=self.encode_kwargs,
        )

        # Initialize Local LLM
        self.llm = ChatOllama(
            model=self.llm_model,
            temperature=self.llm_temperature,
            # Add other parameters if needed
        )

        # Define the prompt template
        self.prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer. Answer must be detailed and well explained.
Helpful answer:
"""

        # Initialize Qdrant client
        self.client = QdrantClient(
            url=self.qdrant_url, prefer_grpc=False
        )

        # Initialize the Qdrant vector store
        self.db = Qdrant(
            client=self.client,
            embeddings=self.embeddings,
            collection_name=self.collection_name
        )

        # Initialize the prompt
        self.prompt = PromptTemplate(
            template=self.prompt_template,
            input_variables=['context', 'question']
        )

        # Initialize the retriever
        self.retriever = self.db.as_retriever(search_kwargs={"k": 1})

        # Define chain type kwargs
        self.chain_type_kwargs = {"prompt": self.prompt}

        # Initialize the RetrievalQA chain with return_source_documents=False
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=False,  # Set to False to return only 'result'
            chain_type_kwargs=self.chain_type_kwargs,
            verbose=False
        )
    def get_response(self, query: str) -> str:
        """
        Processes the user's query and returns the chatbot's response.

        Args:
            query (str): The user's input question.

        Returns:
            str: The chatbot's response.
        """
        try:
            response = self.qa.run(query)
            return response  # 'response' is now a string containing only the 'result'
        except Exception as e:
            st.error(f"⚠️ An error occurred while processing your request: {e}")
            return "⚠️ Sorry, I couldn't process your request at the moment."