Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FastEmbed Embeddings #2083

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
40 changes: 40 additions & 0 deletions embedchain/embedchain/embedder/fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List, Optional, Sequence, Union

try:
from fastembed import TextEmbedding
except ImportError:
raise ValueError("The 'fastembed' package is not installed. Please install it with `pip install fastembed`")

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions

Embedding = Sequence[float]
Embeddings = List[Embedding]


class FastEmbedEmbedder(BaseEmbedder):
"""
Generate embeddings using FastEmbed - https://qdrant.github.io/fastembed/.
Find the list of supported models at https://qdrant.github.io/fastembed/examples/Supported_Models/.
"""
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

self.config.model = self.config.model or "BAAI/bge-small-en-v1.5"

embedding_fn = FastEmbedEmbeddingFunction(config=self.config)
self.set_embedding_fn(embedding_fn=embedding_fn)

vector_dimension = self.config.vector_dimension or VectorDimensions.FASTEMBED.value
self.set_vector_dimension(vector_dimension=vector_dimension)


class FastEmbedEmbeddingFunction:
def __init__(self, config: BaseEmbedderConfig) -> None:
self.config = config
self._model = TextEmbedding(model_name=self.config.model, **self.config.model_kwargs)

def __call__(self, input: Union[list[str], str]) -> List[Embedding]:
embeddings = self._model.embed(input)
return [embedding.tolist() for embedding in embeddings]
2 changes: 2 additions & 0 deletions embedchain/embedchain/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class EmbedderFactory:
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
"fastembed": "embedchain.embedder.fastembed.FastEmbedEmbedder",
"aws_bedrock": "embedchain.embedder.aws_bedrock.AWSBedrockEmbedder",
}
provider_to_config_class = {
Expand All @@ -71,6 +72,7 @@ class EmbedderFactory:
"clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
"fastembed": "embedchain.config.embedder.base.BaseEmbedderConfig",
"aws_bedrock": "embedchain.config.embedder.aws_bedrock.AWSBedrockEmbedderConfig",
}

Expand Down
1 change: 1 addition & 0 deletions embedchain/embedchain/models/vector_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ class VectorDimensions(Enum):
NVIDIA_AI = 1024
COHERE = 384
OLLAMA = 384
FASTEMBED = 384
AMAZON_TITAN_V1 = 1536
AMAZON_TITAN_V2 = 1024
170 changes: 168 additions & 2 deletions embedchain/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions embedchain/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ langchain-cohere = "^0.3.0"
langchain-community = "^0.3.1"
langchain-aws = {version = "^0.2.1", optional = true}
langsmith = "^0.1.17"
fastembed = { version = "^0.4.2", optional = true, python = "<3.13,>=3.8.0" }

[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
Expand Down Expand Up @@ -179,6 +180,7 @@ postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
mysql = ["mysql-connector-python"]
google = ["google-generativeai"]
mistralai = ["langchain-mistralai"]
fastembed = ["fastembed"]
aws = ["langchain-aws"]

[tool.poetry.group.docs.dependencies]
Expand Down
19 changes: 19 additions & 0 deletions embedchain/tests/embedder/test_fastembed_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from unittest.mock import patch

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.fastembed import FastEmbedEmbedder


def test_fastembed_embedder_with_model(monkeypatch):
model = "intfloat/multilingual-e5-large"
model_kwargs = {"threads": 5}
config = BaseEmbedderConfig(model=model, model_kwargs=model_kwargs)
with patch('embedchain.embedder.fastembed.TextEmbedding') as mock_embeddings:
embedder = FastEmbedEmbedder(config=config)
assert embedder.config.model == model
assert embedder.config.model_kwargs == model_kwargs
mock_embeddings.assert_called_once_with(
model_name=model,
threads=5
)
4 changes: 2 additions & 2 deletions mem0/embeddings/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

class EmbedderConfig(BaseModel):
provider: str = Field(
description="Provider of the embedding model (e.g., 'ollama', 'openai')",
description="Provider of the embedding model (e.g., 'ollama', 'openai', 'fastembed')",
default="openai",
)
config: Optional[dict] = Field(description="Configuration for the specific embedding model", default={})

@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together"]:
if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "fastembed"]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")
26 changes: 26 additions & 0 deletions mem0/embeddings/fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase

try:
from fastembed import TextEmbedding
except ImportError as e:
raise ImportError(
"The 'fastembed' package is not installed. Please install it with `pip install fastembed`"
) from e


class FastEmbedEmbedding(EmbeddingBase):
"""
Generate embeddings vector embeddings using FastEmbed - https://qdrant.github.io/fastembed/.
Find the list of supported models at https://qdrant.github.io/fastembed/examples/Supported_Models/.
"""

def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

self._model = TextEmbedding(model_name=self.config.model, **self.config.model_kwargs)

def embed(self, text):
return next(self._model.embed(text)).tolist()
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class EmbedderFactory:
"openai": "mem0.embeddings.openai.OpenAIEmbedding",
"ollama": "mem0.embeddings.ollama.OllamaEmbedding",
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"fastembed": "mem0.embeddings.fastembed.FastEmbedEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
Expand Down