Add functionality to fetch data from reference links in a web page (
vatsalrathod16 authored Oct 15, 2024
1 parent 721d765 commit 20c3aee
Showing 9 changed files with 86 additions and 8 deletions.
3 changes: 3 additions & 0 deletions embedchain/docs/api-reference/app/add.mdx
@@ -15,6 +15,9 @@ title: '📊 add'
<ParamField path="metadata" type="dict" optional>
Any metadata that you want to store with the data source. Metadata is generally really useful for doing metadata filtering on top of semantic search to yield faster search and better results.
</ParamField>
<ParamField path="all_references" type="bool" optional>
This parameter instructs Embedchain to retrieve all the context and information from the specified link, as well as from any reference links on the page.
</ParamField>

## Usage

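As a minimal sketch of the new flag in use (the URL is illustrative; per this commit, `all_references` travels from `add` through the chunker down to the web page loader):

```python
from embedchain import App

app = App()

# Ingest the page itself plus the content of every absolute
# reference link found on it (the behavior added in this commit).
app.add("https://example.com", all_references=True)
```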
13 changes: 10 additions & 3 deletions embedchain/embedchain/chunkers/base_chunker.py
@@ -1,6 +1,6 @@
import hashlib
import logging
from typing import Optional
from typing import Any, Optional

from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
@@ -15,7 +15,14 @@ def __init__(self, text_splitter):
        self.text_splitter = text_splitter
        self.data_type = None

    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
    def create_chunks(
        self,
        loader,
        src,
        app_id=None,
        config: Optional[ChunkerConfig] = None,
        **kwargs: Optional[dict[str, Any]],
    ):
"""
Loads data and chunks it.
Expand All @@ -30,7 +37,7 @@ def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig
id_map = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_result = loader.load_data(src, **kwargs)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
# Prefix app_id in the document id if app_id is not None to
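To illustrate the plumbing this hunk introduces, here is a self-contained sketch of the same pass-through pattern; `StubLoader` and the free-standing `create_chunks` are stand-ins for illustration, not embedchain classes:

```python
from typing import Any, Optional


class StubLoader:
    """Stand-in loader that reports which keyword arguments reached it."""

    def load_data(self, src, **kwargs: Optional[dict[str, Any]]):
        flag = kwargs.get("all_references", False)
        return {
            "doc_id": "doc-1",
            "data": [{"content": f"{src} all_references={flag}", "meta_data": {"url": src}}],
        }


def create_chunks(loader, src, **kwargs):
    # Mirrors the hunk above: keyword arguments pass through to the loader untouched.
    result = loader.load_data(src, **kwargs)
    return [record["content"] for record in result["data"]]


print(create_chunks(StubLoader(), "https://example.com", all_references=True))
# prints: ['https://example.com all_references=True']
```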
10 changes: 8 additions & 2 deletions embedchain/embedchain/data_formatter/data_formatter.py
@@ -1,5 +1,5 @@
from importlib import import_module
from typing import Optional
from typing import Any, Optional

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -40,7 +40,13 @@ def _lazy_load(module_path: str):
        module = import_module(module_path)
        return getattr(module, class_name)

    def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
    def _get_loader(
        self,
        data_type: DataType,
        config: LoaderConfig,
        loader: Optional[BaseLoader],
        **kwargs: Optional[dict[str, Any]],
    ) -> BaseLoader:
        """
        Returns the appropriate data loader for the given data type.
2 changes: 1 addition & 1 deletion embedchain/embedchain/embedchain.py
@@ -329,7 +329,7 @@ def _load_and_embed(
        app_id = self.config.id if self.config is not None else None

        # Create chunks
        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker, **kwargs)
        # spread chunking results
        documents = embeddings_data["documents"]
        metadatas = embeddings_data["metadatas"]
4 changes: 3 additions & 1 deletion embedchain/embedchain/loaders/base_loader.py
@@ -1,11 +1,13 @@
from typing import Any, Optional

from embedchain.helpers.json_serializable import JSONSerializable


class BaseLoader(JSONSerializable):
    def __init__(self):
        pass

    def load_data(self, url):
    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
        """
        Implemented by child classes
        """
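With the widened base signature, a custom loader can accept the same pass-through keywords. A hypothetical subclass (the class name and its behavior are invented for illustration):

```python
import hashlib
from typing import Any, Optional

from embedchain.loaders.base_loader import BaseLoader


class StaticTextLoader(BaseLoader):
    """Hypothetical loader: serves fixed text and tolerates pass-through kwargs."""

    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
        content = f"static content for {url}"
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [{"content": content, "meta_data": {"url": url}}],
        }
```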
28 changes: 27 additions & 1 deletion embedchain/embedchain/loaders/web_page.py
@@ -1,5 +1,6 @@
import hashlib
import logging
from typing import Any, Optional

import requests

@@ -22,14 +23,29 @@ class WebPageLoader(BaseLoader):
    # Shared session for all instances
    _session = requests.Session()

    def load_data(self, url):
    def load_data(self, url, **kwargs: Optional[dict[str, Any]]):
        """Load data from a web page using a shared requests session."""
        all_references = kwargs.get("all_references", False)
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
        }
        response = self._session.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        data = response.content
        reference_links = self.fetch_reference_links(response)
        if all_references:
            for link in reference_links:
                try:
                    response = self._session.get(link, headers=headers, timeout=30)
                    response.raise_for_status()
                    data += response.content
                except Exception as e:
                    logging.error(f"Failed to add URL {link}: {e}")
                    continue

        content = self._get_clean_content(data, url)

        metadata = {"url": url}
@@ -98,3 +114,13 @@ def _get_clean_content(html, url) -> str:
    @classmethod
    def close_session(cls):
        cls._session.close()

    def fetch_reference_links(self, response):
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, "html.parser")
            a_tags = soup.find_all("a", href=True)
            reference_links = [a["href"] for a in a_tags if a["href"].startswith("http")]
            return reference_links
        else:
            logging.error(f"Failed to retrieve the page. Status code: {response.status_code}")
            return []
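Calling the loader directly shows the old and new behavior side by side (network access required; the URL is illustrative):

```python
from embedchain.loaders.web_page import WebPageLoader

loader = WebPageLoader()

# Previous behavior: content of the page only.
page_only = loader.load_data("https://example.com")

# New behavior: page content plus the content of every absolute
# link found on the page, concatenated before cleaning.
page_and_refs = loader.load_data("https://example.com", all_references=True)

WebPageLoader.close_session()
```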
1 change: 1 addition & 0 deletions embedchain/embedchain/vectordb/chroma.py
@@ -136,6 +136,7 @@ def add(
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ) -> Any:
        """
        Add vectors to chroma database
2 changes: 2 additions & 0 deletions embedchain/poetry.lock


31 changes: 31 additions & 0 deletions embedchain/tests/loaders/test_web_page.py
@@ -2,6 +2,7 @@
from unittest.mock import Mock, patch

import pytest
import requests

from embedchain.loaders.web_page import WebPageLoader

@@ -115,3 +116,33 @@ def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
        assert class_name not in content

    assert len(content) > 0


def test_fetch_reference_links_success(web_page_loader):
    # Mock a successful response
    response = Mock(spec=requests.Response)
    response.status_code = 200
    response.content = b"""
    <html>
        <body>
            <a href="http://example.com">Example</a>
            <a href="https://another-example.com">Another Example</a>
            <a href="/relative-link">Relative Link</a>
        </body>
    </html>
    """

    expected_links = ["http://example.com", "https://another-example.com"]
    result = web_page_loader.fetch_reference_links(response)
    assert result == expected_links


def test_fetch_reference_links_failure(web_page_loader):
    # Mock a failed response
    response = Mock(spec=requests.Response)
    response.status_code = 404
    response.content = b""

    expected_links = []
    result = web_page_loader.fetch_reference_links(response)
    assert result == expected_links

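The new tests exercise `fetch_reference_links` in isolation. A complementary sketch for the `all_references` path could stub the shared session; this assumes the `web_page_loader` fixture above and the loader's usual `data`/`meta_data` return shape:

```python
def test_load_data_with_all_references(web_page_loader):
    main_html = b'<html><body><p>hello</p><a href="http://example.com/ref">ref</a></body></html>'
    ref_html = b"<html><body><p>referenced</p></body></html>"

    def fake_get(url, **kwargs):
        # Serve the main page and its single reference link without network access.
        response = Mock(spec=requests.Response)
        response.status_code = 200
        response.content = ref_html if url.endswith("/ref") else main_html
        return response

    with patch.object(web_page_loader._session, "get", side_effect=fake_get):
        result = web_page_loader.load_data("http://example.com", all_references=True)

    content = result["data"][0]["content"]
    assert "hello" in content
    assert "referenced" in content
```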