From 04318ceaff8d44327b1142be44595772a58f9a46 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 10 Apr 2024 08:45:26 +0800 Subject: [PATCH] Add source to the answer for default prompt (#2289) * Add source to the answer for default prompt * Fix qdrant * Fix tests * Update docstring * Fix check files * Fix qdrant test error --- .../qdrant_retrieve_user_proxy_agent.py | 5 ++-- .../contrib/retrieve_user_proxy_agent.py | 16 ++++++++++-- autogen/retrieve_utils.py | 25 +++++++++++++------ test/test_retrieve_utils.py | 10 ++++---- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py index 3f5278c5f8ba..c68ce809d8d8 100644 --- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py @@ -190,12 +190,12 @@ def create_qdrant_from_dir( client.set_model(embedding_model) if custom_text_split_function is not None: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), custom_text_split_function=custom_text_split_function, ) else: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line ) logger.info(f"Found {len(chunks)} chunks.") @@ -298,5 +298,6 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore data = { "ids": [[result.id for result in sublist] for sublist in results], "documents": [[result.document for result in sublist] for sublist in results], + "metadatas": [[result.metadata for result in sublist] for sublist in results], } return data diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index ddc70e06356e..34dbe28d0984 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -34,6 +34,10 @@ User's question is: {input_question} Context is: {input_context} + +The source of the context is: {input_sources} + +If you can answer the question, in the end of your answer, add the source of the context in the format of `Sources: source1, source2, ...`. """ PROMPT_CODE = """You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the @@ -101,7 +105,8 @@ def __init__( following keys: - `task` (Optional, str) - the task of the retrieve chat. Possible values are "code", "qa" and "default". System prompt will be different for different tasks. - The default value is `default`, which supports both code and qa. + The default value is `default`, which supports both code and qa, and provides + source information in the end of the response. - `client` (Optional, chromadb.Client) - the chromadb client. If key not provided, a default client `chromadb.Client()` will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function. @@ -243,6 +248,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc + self._current_docs_in_context = [] # the ids of the current context sources self._search_string = "" # the search string used in the current query # update the termination message function self._is_termination_msg = ( @@ -290,6 +296,7 @@ def _reset(self, intermediate=False): def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]): doc_contents = "" + self._current_docs_in_context = [] current_tokens = 0 _doc_idx = self._doc_idx _tmp_retrieve_count = 0 @@ -310,6 +317,9 @@ def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]): print(colored(func_print, "green"), flush=True) current_tokens += _doc_tokens doc_contents += doc + "\n" + _metadatas = results.get("metadatas") + if isinstance(_metadatas, list) and isinstance(_metadatas[0][idx], dict): + self._current_docs_in_context.append(results["metadatas"][0][idx].get("source", "")) self._doc_idx = idx self._doc_ids.append(results["ids"][0][idx]) self._doc_contents.append(doc) @@ -329,7 +339,9 @@ def _generate_message(self, doc_contents, task="default"): elif task.upper() == "QA": message = PROMPT_QA.format(input_question=self.problem, input_context=doc_contents) elif task.upper() == "DEFAULT": - message = PROMPT_DEFAULT.format(input_question=self.problem, input_context=doc_contents) + message = PROMPT_DEFAULT.format( + input_question=self.problem, input_context=doc_contents, input_sources=self._current_docs_in_context + ) else: raise NotImplementedError(f"task {task} is not implemented.") return message diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 6db47b1d3123..e83f8a80f36b 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -1,7 +1,7 @@ import glob import os import re -from typing import Callable, List, Union +from typing import Callable, List, Tuple, Union from urllib.parse import urlparse import chromadb @@ -160,8 +160,14 @@ def split_files_to_chunks( """Split a list of files into chunks of max_tokens.""" chunks = [] + sources = [] for file in files: + if isinstance(file, tuple): + url = file[1] + file = file[0] + else: + url = None _, file_extension = os.path.splitext(file) file_extension = file_extension.lower() @@ -179,11 +185,13 @@ def split_files_to_chunks( continue # Skip to the next file if no text is available if custom_text_split_function is not None: - chunks += custom_text_split_function(text) + tmp_chunks = custom_text_split_function(text) else: - chunks += split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line) + tmp_chunks = split_text_to_chunks(text, max_tokens, chunk_mode, must_break_at_empty_line) + chunks += tmp_chunks + sources += [{"source": url if url else file}] * len(tmp_chunks) - return chunks + return chunks, sources def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True): @@ -267,7 +275,7 @@ def parse_html_to_markdown(html: str, url: str = None) -> str: return webpage_text -def get_file_from_url(url: str, save_path: str = None): +def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]: """Download a file from a URL.""" if save_path is None: save_path = "tmp/chromadb" @@ -303,7 +311,7 @@ def get_file_from_url(url: str, save_path: str = None): with open(save_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - return save_path + return save_path, url def is_url(string: str): @@ -383,12 +391,12 @@ def create_vector_db_from_dir( length = len(collection.get()["ids"]) if custom_text_split_function is not None: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), custom_text_split_function=custom_text_split_function, ) else: - chunks = split_files_to_chunks( + chunks, sources = split_files_to_chunks( get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, @@ -401,6 +409,7 @@ def create_vector_db_from_dir( collection.upsert( documents=chunks[i:end_idx], ids=[f"doc_{j+length}" for j in range(i, end_idx)], # unique for each doc + metadatas=sources[i:end_idx], ) except ValueError as e: logger.warning(f"{e}") diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index a7235f10bd96..a8c384555577 100755 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -69,7 +69,7 @@ def test_extract_text_from_pdf(self): def test_split_files_to_chunks(self): pdf_file_path = os.path.join(test_dir, "example.pdf") txt_file_path = os.path.join(test_dir, "example.txt") - chunks = split_files_to_chunks([pdf_file_path, txt_file_path]) + chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path]) assert all( isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip() for chunk in chunks @@ -81,7 +81,7 @@ def test_get_files_from_dir(self): pdf_file_path = os.path.join(test_dir, "example.pdf") txt_file_path = os.path.join(test_dir, "example.txt") files = get_files_from_dir([pdf_file_path, txt_file_path]) - assert all(os.path.isfile(file) for file in files) + assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files) files = get_files_from_dir( [ pdf_file_path, @@ -91,7 +91,7 @@ def test_get_files_from_dir(self): ], recursive=True, ) - assert all(os.path.isfile(file) for file in files) + assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files) files = get_files_from_dir( [ pdf_file_path, @@ -102,7 +102,7 @@ def test_get_files_from_dir(self): recursive=True, types=["pdf", "txt"], ) - assert all(os.path.isfile(file) for file in files) + assert all(os.path.isfile(file) if isinstance(file, str) else os.path.isfile(file[0]) for file in files) assert len(files) == 3 def test_is_url(self): @@ -243,7 +243,7 @@ def test_unstructured(self): pdf_file_path = os.path.join(test_dir, "example.pdf") txt_file_path = os.path.join(test_dir, "example.txt") word_file_path = os.path.join(test_dir, "example.docx") - chunks = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path]) + chunks, _ = split_files_to_chunks([pdf_file_path, txt_file_path, word_file_path]) assert all( isinstance(chunk, str) and "AutoGen is an advanced tool designed to assist developers" in chunk.strip() for chunk in chunks