diff --git a/embedchain/docs/api-reference/app/add.mdx b/embedchain/docs/api-reference/app/add.mdx
index 4586889843..21b24de345 100644
--- a/embedchain/docs/api-reference/app/add.mdx
+++ b/embedchain/docs/api-reference/app/add.mdx
@@ -15,6 +15,9 @@ title: '📊 add'
Any metadata that you want to store with the data source. Metadata is generally really useful for doing metadata filtering on top of semantic search to yield faster search and better results.
+
+  Passing `all_references=True` instructs Embedchain to retrieve the content of the specified link as well as the content of every absolute reference link found on that page.
+
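+  A minimal sketch (assuming `app` is an initialized `App`; extra keyword arguments to `add` are forwarded down to the web page loader):
+
+  ```python
+  app.add("https://docs.embedchain.ai", all_references=True)
+  ```
+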
## Usage
diff --git a/embedchain/embedchain/chunkers/base_chunker.py b/embedchain/embedchain/chunkers/base_chunker.py
index 274ae54759..1f04a7d3f6 100644
--- a/embedchain/embedchain/chunkers/base_chunker.py
+++ b/embedchain/embedchain/chunkers/base_chunker.py
@@ -1,6 +1,6 @@
import hashlib
import logging
-from typing import Optional
+from typing import Any, Optional
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ def __init__(self, text_splitter):
self.text_splitter = text_splitter
self.data_type = None
- def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
+ def create_chunks(
+ self,
+ loader,
+ src,
+ app_id=None,
+ config: Optional[ChunkerConfig] = None,
+ **kwargs: Optional[dict[str, Any]],
+ ):
"""
Loads data and chunks it.
@@ -30,7 +37,7 @@ def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig
id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
- data_result = loader.load_data(src)
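+        # Forward loader-specific options (e.g. `all_references` for the web page loader) untouched.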
+ data_result = loader.load_data(src, **kwargs)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
# Prefix app_id in the document id if app_id is not None to
diff --git a/embedchain/embedchain/data_formatter/data_formatter.py b/embedchain/embedchain/data_formatter/data_formatter.py
index 398d16a3fa..72923888dc 100644
--- a/embedchain/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/embedchain/data_formatter/data_formatter.py
@@ -1,5 +1,5 @@
from importlib import import_module
-from typing import Optional
+from typing import Any, Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ def _lazy_load(module_path: str):
module = import_module(module_path)
return getattr(module, class_name)
- def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
+ def _get_loader(
+ self,
+ data_type: DataType,
+ config: LoaderConfig,
+ loader: Optional[BaseLoader],
+ **kwargs: Optional[dict[str, Any]],
+ ) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
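+        Any extra kwargs are accepted so loader-specific options can be passed along.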
diff --git a/embedchain/embedchain/embedchain.py b/embedchain/embedchain/embedchain.py
index f7271d501a..4a1a4dc09c 100644
--- a/embedchain/embedchain/embedchain.py
+++ b/embedchain/embedchain/embedchain.py
@@ -329,7 +329,7 @@ def _load_and_embed(
app_id = self.config.id if self.config is not None else None
# Create chunks
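+        # Extra kwargs (e.g. `all_references`) travel through the chunker to the loader.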
- embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
+ embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
diff --git a/embedchain/embedchain/loaders/base_loader.py b/embedchain/embedchain/loaders/base_loader.py
index c45282e098..9dccfd539d 100644
--- a/embedchain/embedchain/loaders/base_loader.py
+++ b/embedchain/embedchain/loaders/base_loader.py
@@ -1,3 +1,5 @@
+from typing import Any, Optional
+
from embedchain.helpers.json_serializable import JSONSerializable
@@ -5,7 +7,7 @@ class BaseLoader(JSONSerializable):
def __init__(self):
pass
- def load_data(self, url):
+ def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
"""
Implemented by child classes
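+
+        Loader-specific options arrive via **kwargs (e.g. `all_references`
+        for the web page loader).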
"""
diff --git a/embedchain/embedchain/loaders/web_page.py b/embedchain/embedchain/loaders/web_page.py
index e3dd5991d3..848bc20388 100644
--- a/embedchain/embedchain/loaders/web_page.py
+++ b/embedchain/embedchain/loaders/web_page.py
@@ -1,5 +1,6 @@
import hashlib
import logging
+from typing import Any, Optional
import requests
@@ -22,14 +23,29 @@ class WebPageLoader(BaseLoader):
# Shared session for all instances
_session = requests.Session()
- def load_data(self, url):
+ def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
"""Load data from a web page using a shared requests' session."""
+        # Optional flag: when True, also fetch every absolute link referenced on the page.
+        all_references = kwargs.get("all_references", False)
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36", # noqa:E501
}
response = self._session.get(url, headers=headers, timeout=30)
response.raise_for_status()
data = response.content
+ reference_links = self.fetch_reference_links(response)
+        if all_references:
+            for link in reference_links:
+                try:
+                    ref_response = self._session.get(link, headers=headers, timeout=30)
+                    ref_response.raise_for_status()
+                    data += ref_response.content
+                except Exception as e:
+                    logging.error(f"Failed to add reference URL {link}: {e}")
+
content = self._get_clean_content(data, url)
metadata = {"url": url}
@@ -98,3 +114,13 @@ def _get_clean_content(html, url) -> str:
@classmethod
def close_session(cls):
cls._session.close()
+
+    def fetch_reference_links(self, response):
+        """Extract absolute http(s) links from the page's anchor tags."""
+        if response.status_code == 200:
+            soup = BeautifulSoup(response.content, "html.parser")
+            a_tags = soup.find_all("a", href=True)
+            return [a["href"] for a in a_tags if a["href"].startswith("http")]
+        logging.error(f"Failed to retrieve the page. Status code: {response.status_code}")
+        return []
diff --git a/embedchain/embedchain/vectordb/chroma.py b/embedchain/embedchain/vectordb/chroma.py
index de73397e22..e8d3e3336b 100644
--- a/embedchain/embedchain/vectordb/chroma.py
+++ b/embedchain/embedchain/vectordb/chroma.py
@@ -136,6 +136,7 @@ def add(
documents: list[str],
metadatas: list[object],
ids: list[str],
+ **kwargs: Optional[dict[str, Any]],
) -> Any:
"""
Add vectors to chroma database
diff --git a/embedchain/tests/loaders/test_web_page.py b/embedchain/tests/loaders/test_web_page.py
index 3134d4d0c1..46036ee200 100644
--- a/embedchain/tests/loaders/test_web_page.py
+++ b/embedchain/tests/loaders/test_web_page.py
@@ -2,6 +2,7 @@
from unittest.mock import Mock, patch
import pytest
+import requests
from embedchain.loaders.web_page import WebPageLoader
@@ -115,3 +116,33 @@ def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
assert class_name not in content
assert len(content) > 0
+
+
+def test_fetch_reference_links_success(web_page_loader):
+ # Mock a successful response
+ response = Mock(spec=requests.Response)
+ response.status_code = 200
+    response.content = b"""
+    <html>
+        <body>
+            <a href="http://example.com">Example</a>
+            <a href="https://another-example.com">Another Example</a>
+            <a href="/relative-link">Relative Link</a>
+        </body>
+    </html>
+    """
+
+ expected_links = ["http://example.com", "https://another-example.com"]
+ result = web_page_loader.fetch_reference_links(response)
+ assert result == expected_links
+
+
+def test_fetch_reference_links_failure(web_page_loader):
+ # Mock a failed response
+ response = Mock(spec=requests.Response)
+ response.status_code = 404
+ response.content = b""
+
+ expected_links = []
+ result = web_page_loader.fetch_reference_links(response)
+ assert result == expected_links
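+
+
+def test_fetch_reference_links_skips_relative_links(web_page_loader):
+    # Sketch of an extra check: only absolute http(s) hrefs are returned;
+    # relative and fragment links fail the startswith("http") filter.
+    response = Mock(spec=requests.Response)
+    response.status_code = 200
+    response.content = b'<html><body><a href="/docs">Docs</a><a href="#top">Top</a></body></html>'
+
+    assert web_page_loader.fetch_reference_links(response) == []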