From 20c3aee636de7a7041f5c63126ceb80a8014de9e Mon Sep 17 00:00:00 2001
From: Vatsal Rathod
Date: Tue, 15 Oct 2024 07:26:35 -0400
Subject: [PATCH] Adding fetching data functionality for reference links in the web page (#1806)

---
 embedchain/docs/api-reference/app/add.mdx     |  3 ++
 .../embedchain/chunkers/base_chunker.py       | 13 ++++++--
 .../data_formatter/data_formatter.py          | 10 ++++--
 embedchain/embedchain/embedchain.py           |  2 +-
 embedchain/embedchain/loaders/base_loader.py  |  4 ++-
 embedchain/embedchain/loaders/web_page.py     | 28 ++++++++++++++++-
 embedchain/embedchain/vectordb/chroma.py      |  1 +
 embedchain/poetry.lock                        |  2 ++
 embedchain/tests/loaders/test_web_page.py     | 31 +++++++++++++++++++
 9 files changed, 86 insertions(+), 8 deletions(-)

diff --git a/embedchain/docs/api-reference/app/add.mdx b/embedchain/docs/api-reference/app/add.mdx
index 4586889843..21b24de345 100644
--- a/embedchain/docs/api-reference/app/add.mdx
+++ b/embedchain/docs/api-reference/app/add.mdx
@@ -15,6 +15,9 @@ title: '📊 add'
   Any metadata that you want to store with the data source. Metadata is
   generally really useful for doing metadata filtering on top of semantic
   search to yield faster search and better results.
 </ParamField>
+<ParamField path="all_references" type="bool" optional>
+  This parameter instructs Embedchain to retrieve all the context and information from the specified link, as well as from any reference links on the page.
+</ParamField>
 
 ## Usage
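Per the documentation change above, `all_references` is passed as a keyword argument to `add` and travels via `**kwargs` down to `WebPageLoader.load_data`. A minimal usage sketch, assuming a default-configured `App` (which needs a working embedder/LLM key) and a purely illustrative URL:

```python
from embedchain import App

app = App()

# Ingest the page itself plus every absolute http(s) link found on it.
# all_references is forwarded through **kwargs to WebPageLoader.load_data.
app.add(
    "https://docs.example.com/start",  # hypothetical URL
    data_type="web_page",
    all_references=True,
)

print(app.query("Summarize the linked reference pages."))
```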
diff --git a/embedchain/embedchain/chunkers/base_chunker.py b/embedchain/embedchain/chunkers/base_chunker.py
index 274ae54759..1f04a7d3f6 100644
--- a/embedchain/embedchain/chunkers/base_chunker.py
+++ b/embedchain/embedchain/chunkers/base_chunker.py
@@ -1,6 +1,6 @@
 import hashlib
 import logging
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ def __init__(self, text_splitter):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
+    def create_chunks(
+        self,
+        loader,
+        src,
+        app_id=None,
+        config: Optional[ChunkerConfig] = None,
+        **kwargs: Optional[dict[str, Any]],
+    ):
         """
         Loads data and chunks it.
 
@@ -30,7 +37,7 @@ def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig
         id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
         logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
-        data_result = loader.load_data(src)
+        data_result = loader.load_data(src, **kwargs)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         # Prefix app_id in the document id if app_id is not None to
diff --git a/embedchain/embedchain/data_formatter/data_formatter.py b/embedchain/embedchain/data_formatter/data_formatter.py
index 398d16a3fa..72923888dc 100644
--- a/embedchain/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/embedchain/data_formatter/data_formatter.py
@@ -1,5 +1,5 @@
 from importlib import import_module
-from typing import Optional
+from typing import Any, Optional
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ def _lazy_load(module_path: str):
         module = import_module(module_path)
         return getattr(module, class_name)
 
-    def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
+    def _get_loader(
+        self,
+        data_type: DataType,
+        config: LoaderConfig,
+        loader: Optional[BaseLoader],
+        **kwargs: Optional[dict[str, Any]],
+    ) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
 
diff --git a/embedchain/embedchain/embedchain.py b/embedchain/embedchain/embedchain.py
index f7271d501a..4a1a4dc09c 100644
--- a/embedchain/embedchain/embedchain.py
+++ b/embedchain/embedchain/embedchain.py
@@ -329,7 +329,7 @@ def _load_and_embed(
         app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
diff --git a/embedchain/embedchain/loaders/base_loader.py b/embedchain/embedchain/loaders/base_loader.py
index c45282e098..9dccfd539d 100644
--- a/embedchain/embedchain/loaders/base_loader.py
+++ b/embedchain/embedchain/loaders/base_loader.py
@@ -1,3 +1,5 @@
+from typing import Any, Optional
+
 from embedchain.helpers.json_serializable import JSONSerializable
 
 
@@ -5,7 +7,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass
 
-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """
         Implemented by child classes
         """
diff --git a/embedchain/embedchain/loaders/web_page.py b/embedchain/embedchain/loaders/web_page.py
index e3dd5991d3..848bc20388 100644
--- a/embedchain/embedchain/loaders/web_page.py
+++ b/embedchain/embedchain/loaders/web_page.py
@@ -1,5 +1,6 @@
 import hashlib
 import logging
+from typing import Any, Optional
 
 import requests
 
@@ -22,14 +23,29 @@ class WebPageLoader(BaseLoader):
     # Shared session for all instances
     _session = requests.Session()
 
-    def load_data(self, url):
+    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
         """Load data from a web page using a shared requests' session."""
+        all_references = False
+        for key, value in kwargs.items():
+            if key == "all_references":
+                all_references = kwargs["all_references"]
         headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
         }
         response = self._session.get(url, headers=headers, timeout=30)
         response.raise_for_status()
         data = response.content
+        reference_links = self.fetch_reference_links(response)
+        if all_references:
+            for i in reference_links:
+                try:
+                    response = self._session.get(i, headers=headers, timeout=30)
+                    response.raise_for_status()
+                    data += response.content
+                except Exception as e:
+                    logging.error(f"Failed to add URL {i}: {e}")
+                    continue
+
         content = self._get_clean_content(data, url)
 
         metadata = {"url": url}
@@ -98,3 +114,13 @@ def _get_clean_content(html, url) -> str:
     @classmethod
     def close_session(cls):
         cls._session.close()
+
+    def fetch_reference_links(self, response):
+        if response.status_code == 200:
+            soup = BeautifulSoup(response.content, "html.parser")
+            a_tags = soup.find_all("a", href=True)
+            reference_links = [a["href"] for a in a_tags if a["href"].startswith("http")]
+            return reference_links
+        else:
+            logging.error(f"Failed to retrieve the page. Status code: {response.status_code}")
+            return []
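Two notes on the loader change above: the `for key, value in kwargs.items()` scan is a verbose spelling of `kwargs.get("all_references", False)`, and only absolute `http(s)` hrefs are followed, with unreachable references logged and skipped rather than failing the whole load. A condensed, self-contained sketch of the same flow (the URL is a placeholder; `bs4` is the `beautifulsoup4` package embedchain already depends on):

```python
import logging

import requests
from bs4 import BeautifulSoup

session = requests.Session()
HEADERS = {"User-Agent": "Mozilla/5.0"}


def fetch_reference_links(response: requests.Response) -> list[str]:
    # Mirror WebPageLoader.fetch_reference_links: keep only absolute http(s) hrefs.
    soup = BeautifulSoup(response.content, "html.parser")
    return [a["href"] for a in soup.find_all("a", href=True) if a["href"].startswith("http")]


def load_page(url: str, **kwargs) -> bytes:
    # Condensed form of the kwargs loop in the patch.
    all_references = kwargs.get("all_references", False)
    response = session.get(url, headers=HEADERS, timeout=30)
    response.raise_for_status()
    data = response.content
    if all_references:
        for link in fetch_reference_links(response):
            try:
                sub = session.get(link, headers=HEADERS, timeout=30)
                sub.raise_for_status()
                data += sub.content
            except Exception as e:
                # Skip unreachable references instead of failing the whole load.
                logging.error(f"Failed to add URL {link}: {e}")
    return data


if __name__ == "__main__":
    html = load_page("https://example.com", all_references=True)  # placeholder URL
    print(len(html), "bytes fetched")
```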
diff --git a/embedchain/embedchain/vectordb/chroma.py b/embedchain/embedchain/vectordb/chroma.py
index de73397e22..e8d3e3336b 100644
--- a/embedchain/embedchain/vectordb/chroma.py
+++ b/embedchain/embedchain/vectordb/chroma.py
@@ -136,6 +136,7 @@ def add(
         documents: list[str],
         metadatas: list[object],
         ids: list[str],
+        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock
index 4784f8dc0f..1edf130d35 100644
--- a/embedchain/poetry.lock
+++ b/embedchain/poetry.lock
@@ -2644,6 +2644,7 @@ description = "Client library to connect to the LangSmith LLM Tracing and Evalua
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
+    {file = "langsmith-0.1.126-py3-none-any.whl", hash = "sha256:16c38ba5dae37a3cc715b6bc5d87d9579228433c2f34d6fa328345ee2b2bcc2a"},
     {file = "langsmith-0.1.126.tar.gz", hash = "sha256:40f72e2d1d975473dd69269996053122941c1252915bcea55787607e2a7f949a"},
 ]
 
@@ -6750,3 +6751,4 @@ weaviate = ["weaviate-client"]
 lock-version = "2.0"
 python-versions = ">=3.9,<=3.13"
 content-hash = "ec8a87e5281b7fa0c2c28f24c2562e823f0c546a24da2bb285b2f239b7b1758d"
+
diff --git a/embedchain/tests/loaders/test_web_page.py b/embedchain/tests/loaders/test_web_page.py
index 3134d4d0c1..46036ee200 100644
--- a/embedchain/tests/loaders/test_web_page.py
+++ b/embedchain/tests/loaders/test_web_page.py
@@ -2,6 +2,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
+import requests
 
 from embedchain.loaders.web_page import WebPageLoader
 
@@ -115,3 +116,33 @@ def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
     assert class_name not in content
 
     assert len(content) > 0
+
+
+def test_fetch_reference_links_success(web_page_loader):
+    # Mock a successful response
+    response = Mock(spec=requests.Response)
+    response.status_code = 200
+    response.content = b"""
+    <html>
+    <body>
+        <a href="http://example.com">Example</a>
+        <a href="https://another-example.com">Another Example</a>
+        <a href="/relative-link">Relative Link</a>
+    </body>
+    </html>
+    """
+
+    expected_links = ["http://example.com", "https://another-example.com"]
+    result = web_page_loader.fetch_reference_links(response)
+    assert result == expected_links
+
+
+def test_fetch_reference_links_failure(web_page_loader):
+    # Mock a failed response
+    response = Mock(spec=requests.Response)
+    response.status_code = 404
+    response.content = b""
+
+    expected_links = []
+    result = web_page_loader.fetch_reference_links(response)
+    assert result == expected_links
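The new tests exercise `fetch_reference_links` in isolation (the relative href in the fixture is illustrative; any non-http value exercises the filter), but not the `all_references` path through `load_data`. A sketch of such a test, assuming the loader's usual return shape of `{"doc_id": ..., "data": [{"content": ...}]}` and that `_get_clean_content` preserves ordinary paragraph text:

```python
from unittest.mock import Mock, patch

from embedchain.loaders.web_page import WebPageLoader


def test_load_data_follows_references():
    loader = WebPageLoader()

    main_page = Mock()
    main_page.status_code = 200
    main_page.content = b'<html><body><a href="http://example.com/ref">ref</a></body></html>'

    ref_page = Mock()
    ref_page.status_code = 200
    ref_page.content = b"<html><body><p>Referenced content</p></body></html>"

    # First get() returns the main page, the second the referenced page.
    with patch.object(WebPageLoader._session, "get", side_effect=[main_page, ref_page]) as mock_get:
        result = loader.load_data("http://example.com", all_references=True)

    assert mock_get.call_count == 2
    assert "Referenced content" in result["data"][0]["content"]
```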