From 08068ed16a0c07bc6e1e0c210ba1052dcb7b2e40 Mon Sep 17 00:00:00 2001 From: Qihang <35721413+Mark-Zeng@users.noreply.github.com> Date: Wed, 6 Sep 2023 22:15:13 -0700 Subject: [PATCH 1/4] Add non-GCP PALM providers --- docs/advanced/app_types.mdx | 2 ++ embedchain/embedder/palm_embedder.py | 26 +++++++++++++++++++++ embedchain/llm/palm_llm.py | 31 +++++++++++++++++++++++++ embedchain/models/EmbeddingFunctions.py | 1 + pyproject.toml | 1 + 5 files changed, 61 insertions(+) create mode 100644 embedchain/embedder/palm_embedder.py create mode 100644 embedchain/llm/palm_llm.py diff --git a/docs/advanced/app_types.mdx b/docs/advanced/app_types.mdx index 147d769f05..b5ad43fc94 100644 --- a/docs/advanced/app_types.mdx +++ b/docs/advanced/app_types.mdx @@ -99,12 +99,14 @@ app = CustomApp( - GPT4ALL - AZURE_OPENAI - LLAMA2 + - PALM - Following embedding functions are available for an embedding function - OPENAI - HUGGING_FACE - VERTEX_AI - GPT4ALL - AZURE_OPENAI + - PALM ### PersonApp diff --git a/embedchain/embedder/palm_embedder.py b/embedchain/embedder/palm_embedder.py new file mode 100644 index 0000000000..703700c649 --- /dev/null +++ b/embedchain/embedder/palm_embedder.py @@ -0,0 +1,26 @@ +import os +from typing import Optional + +from langchain.embeddings import GooglePalmEmbeddings + +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.models import EmbeddingFunctions + + +class PalmEmbedder(BaseEmbedder): + def __init__(self, config: Optional[BaseEmbedderConfig] = None): + super().__init__(config=config) + + if os.getenv("GOOGLE_API_KEY") is None: + raise ValueError("GOOGLE_API_KEY environment variables not provided") + + model = "models/embedding-gecko-001" + if (config is not None) and (config.model is not None): + model = config.model + + embeddings = GooglePalmEmbeddings(model_name=model) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) + + self.set_embedding_fn(embedding_fn=embedding_fn) + self.set_vector_dimension(vector_dimension=EmbeddingFunctions.PALM.value) diff --git a/embedchain/llm/palm_llm.py b/embedchain/llm/palm_llm.py new file mode 100644 index 0000000000..0c34320b4e --- /dev/null +++ b/embedchain/llm/palm_llm.py @@ -0,0 +1,31 @@ +import os +from typing import Optional + +from embedchain.config import BaseLlmConfig +from embedchain.helper_classes.json_serializable import register_deserializable +from embedchain.llm.base_llm import BaseLlm + + +@register_deserializable +class PalmLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config=config) + + def get_llm_model_answer(self, prompt): + return PalmLlm._get_athrophic_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str: + if os.getenv("GOOGLE_API_KEY") is None: + raise ValueError("GOOGLE_API_KEY environment variables not provided") + + from langchain.chat_models import ChatGooglePalm + + model = "models/chat-bison-001" + if (config is not None) and (config.model is not None): + model = config.model + + chat = ChatGooglePalm(temperature=config.temperature, top_p=config.top_p, model_name=model) + messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) + + return chat(messages).content diff --git a/embedchain/models/EmbeddingFunctions.py b/embedchain/models/EmbeddingFunctions.py index 7967c45a22..2b0075ad74 100644 --- a/embedchain/models/EmbeddingFunctions.py +++ b/embedchain/models/EmbeddingFunctions.py @@ -6,3 +6,4 @@ class EmbeddingFunctions(Enum): HUGGING_FACE = "HUGGING_FACE" VERTEX_AI = "VERTEX_AI" GPT4ALL = "GPT4ALL" + PALM = "PALM" diff --git a/pyproject.toml b/pyproject.toml index 1f3a07c2fa..3d613142be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ youtube-transcript-api = "^0.6.1" beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" pytube = "^15.0.0" +google-generativeai = "^0.1.0" llama-index = { version = "^0.7.21", optional = true } sentence-transformers = { version = "^2.2.2", optional = true } torch = { version = ">=2.0.0, !=2.0.1", optional = true } From 389ac02558e20e24b97f50f864daa8a259e04a2c Mon Sep 17 00:00:00 2001 From: Qihang <35721413+Mark-Zeng@users.noreply.github.com> Date: Mon, 11 Sep 2023 19:35:49 -0700 Subject: [PATCH 2/4] remove indirect dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cbde310276..afa1113a78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,6 @@ youtube-transcript-api = "^0.6.1" beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" pytube = "^15.0.0" -google-generativeai = "^0.1.0" llama-index = { version = "^0.7.21", optional = true } sentence-transformers = { version = "^2.2.2", optional = true } torch = { version = ">=2.0.0, !=2.0.1", optional = true } From 5a0eb9c3f98340ae0de70bb853cfd2f16692f1ed Mon Sep 17 00:00:00 2001 From: Qihang <35721413+Mark-Zeng@users.noreply.github.com> Date: Wed, 13 Sep 2023 00:37:34 -0700 Subject: [PATCH 3/4] Update PALM dependency and add module installment checks --- embedchain/embedder/{palm_embedder.py => palm.py} | 7 ++++++- embedchain/llm/{palm_llm.py => palm.py} | 9 +++++++-- pyproject.toml | 3 ++- 3 files changed, 15 insertions(+), 4 deletions(-) rename embedchain/embedder/{palm_embedder.py => palm.py} (71%) rename embedchain/llm/{palm_llm.py => palm.py} (71%) diff --git a/embedchain/embedder/palm_embedder.py b/embedchain/embedder/palm.py similarity index 71% rename from embedchain/embedder/palm_embedder.py rename to embedchain/embedder/palm.py index 703700c649..bf7f3ed37a 100644 --- a/embedchain/embedder/palm_embedder.py +++ b/embedchain/embedder/palm.py @@ -2,14 +2,19 @@ from typing import Optional from langchain.embeddings import GooglePalmEmbeddings +from importlib.util import find_spec from embedchain.config import BaseEmbedderConfig -from embedchain.embedder.base_embedder import BaseEmbedder +from embedchain.embedder.base import BaseEmbedder from embedchain.models import EmbeddingFunctions class PalmEmbedder(BaseEmbedder): def __init__(self, config: Optional[BaseEmbedderConfig] = None): + if find_spec("google.generativeai") is None: + raise ModuleNotFoundError( + "The google-generativeai python package is not installed. Please install it with `pip install --upgrade embedchain[palm]`" # noqa E501 + ) super().__init__(config=config) if os.getenv("GOOGLE_API_KEY") is None: diff --git a/embedchain/llm/palm_llm.py b/embedchain/llm/palm.py similarity index 71% rename from embedchain/llm/palm_llm.py rename to embedchain/llm/palm.py index 0c34320b4e..eb1a2528d9 100644 --- a/embedchain/llm/palm_llm.py +++ b/embedchain/llm/palm.py @@ -1,14 +1,19 @@ import os +from importlib.util import find_spec from typing import Optional from embedchain.config import BaseLlmConfig -from embedchain.helper_classes.json_serializable import register_deserializable -from embedchain.llm.base_llm import BaseLlm +from embedchain.helper.json_serializable import register_deserializable +from embedchain.llm.base import BaseLlm @register_deserializable class PalmLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): + if find_spec("google.generativeai") is None: + raise ModuleNotFoundError( + "The google-generativeai python package is not installed. Please install it with `pip install --upgrade embedchain[palm]`" # noqa E501 + ) super().__init__(config=config) def get_llm_model_answer(self, prompt): diff --git a/pyproject.toml b/pyproject.toml index afa1113a78..4cb389acb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,7 +103,7 @@ twilio = { version = "^8.5.0", optional = true } fastapi-poe = { version = "0.0.16", optional = true } discord = { version = "^2.3.2", optional = true } slack-sdk = { version = "3.21.3", optional = true } - +google-generativeai = { version = "^0.1.0", optional = true} [tool.poetry.group.dev.dependencies] @@ -120,6 +120,7 @@ isort = "^5.12.0" streamlit = ["streamlit"] community = ["llama-index"] opensource = ["sentence-transformers", "torch", "gpt4all"] +palm = ["google-generativeai"] elasticsearch = ["elasticsearch"] poe = ["fastapi-poe"] discord = ["discord"] From 90cdae47b16de42c71bb09679f5f2fc6490f549a Mon Sep 17 00:00:00 2001 From: Qihang <35721413+Mark-Zeng@users.noreply.github.com> Date: Sun, 17 Sep 2023 21:49:24 -0700 Subject: [PATCH 4/4] simplify logic --- embedchain/embedder/palm.py | 7 +++---- embedchain/llm/palm.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/embedchain/embedder/palm.py b/embedchain/embedder/palm.py index bf7f3ed37a..853a6512ba 100644 --- a/embedchain/embedder/palm.py +++ b/embedchain/embedder/palm.py @@ -20,11 +20,10 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None): if os.getenv("GOOGLE_API_KEY") is None: raise ValueError("GOOGLE_API_KEY environment variables not provided") - model = "models/embedding-gecko-001" - if (config is not None) and (config.model is not None): - model = config.model + if not self.config.model: + self.config.model = "models/embedding-gecko-001" - embeddings = GooglePalmEmbeddings(model_name=model) + embeddings = GooglePalmEmbeddings(model_name=self.config.model) embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) self.set_embedding_fn(embedding_fn=embedding_fn) diff --git a/embedchain/llm/palm.py b/embedchain/llm/palm.py index eb1a2528d9..a80ed01c98 100644 --- a/embedchain/llm/palm.py +++ b/embedchain/llm/palm.py @@ -26,11 +26,10 @@ def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str: from langchain.chat_models import ChatGooglePalm - model = "models/chat-bison-001" - if (config is not None) and (config.model is not None): - model = config.model + if not config.model: + config.model = "models/chat-bison-001" - chat = ChatGooglePalm(temperature=config.temperature, top_p=config.top_p, model_name=model) + chat = ChatGooglePalm(temperature=config.temperature, top_p=config.top_p, model_name=config.model) messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt) return chat(messages).content