diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index dadda83a23..99fc6a4fd3 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -21,7 +21,8 @@
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
+from embedchain.models.data_type import (DataType, DirectDataType,
+                                          IndirectDataType, SpecialDataType)
 from embedchain.utils import detect_datatype
 from embedchain.vectordb.base import BaseVectorDB
 
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index d21cfb6b7b..5e3191a637 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
-import openai
+from langchain.chat_models import ChatOpenAI
+from langchain.schema import HumanMessage, SystemMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
@@ -12,31 +13,32 @@ class OpenAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
 
-    # NOTE: This class does not use langchain. One reason is that `top_p` is not supported.
-
     def get_llm_model_answer(self, prompt):
-        messages = []
-        if self.config.system_prompt:
-            messages.append({"role": "system", "content": self.config.system_prompt})
-        messages.append({"role": "user", "content": prompt})
-        response = openai.ChatCompletion.create(
-            model=self.config.model or "gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=self.config.temperature,
-            max_tokens=self.config.max_tokens,
-            top_p=self.config.top_p,
-            stream=self.config.stream,
-        )
+        response = OpenAILlm._get_answer(prompt, self.config)
         if self.config.stream:
-            return self._stream_llm_model_response(response)
+            return response
+        else:
+            return response.content
+
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        messages = []
+        if config.system_prompt:
+            messages.append(SystemMessage(content=config.system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        kwargs = {
+            "model": config.model or "gpt-3.5-turbo-0613",
+            "temperature": config.temperature,
+            "max_tokens": config.max_tokens,
+            "model_kwargs": {},
+        }
+        if config.top_p:
+            kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler
+
+            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
-            return response["choices"][0]["message"]["content"]
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
+            chat = ChatOpenAI(**kwargs)
+        return chat(messages)
diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py
index f916167fb2..625ced4137 100644
--- a/tests/llm/test_query.py
+++ b/tests/llm/test_query.py
@@ -46,41 +46,29 @@ def test_query(self):
         self.assertEqual(input_query_arg, "Test query")
         mock_answer.assert_called_once()
 
-    @patch("openai.ChatCompletion.create")
-    def test_query_config_app_passing(self, mock_create):
-        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
+    def test_query_config_app_passing(self, mock_get_answer):
+        mock_get_answer.return_value = MagicMock()
+        mock_get_answer.return_value.content = "Test answer"
 
         config = AppConfig(collect_metrics=False)
         chat_config = BaseLlmConfig(system_prompt="Test system prompt")
         app = App(config=config, llm_config=chat_config)
+        answer = app.llm.get_llm_model_answer("Test query")
 
-        app.llm.get_llm_model_answer("Test query")
-
-        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
-        messages_arg = mock_create.call_args.kwargs["messages"]
-        self.assertTrue(messages_arg[0].get("role"), "system")
-        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
-        self.assertTrue(messages_arg[1].get("role"), "user")
-        self.assertEqual(messages_arg[1].get("content"), "Test query")
-
-        # TODO: Add tests for other config variables
-
-    @patch("openai.ChatCompletion.create")
-    def test_app_passing(self, mock_create):
-        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+        self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
+        self.assertEqual(answer, "Test answer")
 
+    @patch("embedchain.llm.openai.OpenAILlm._get_answer")
+    def test_app_passing(self, mock_get_answer):
+        mock_get_answer.return_value = MagicMock()
+        mock_get_answer.return_value.content = "Test answer"
         config = AppConfig(collect_metrics=False)
         chat_config = BaseLlmConfig()
         app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
-
+        answer = app.llm.get_llm_model_answer("Test query")
         self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
-
-        app.llm.get_llm_model_answer("Test query")
-
-        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
-        messages_arg = mock_create.call_args.kwargs["messages"]
-        self.assertTrue(messages_arg[0].get("role"), "system")
-        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
+        self.assertEqual(answer, "Test answer")
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_query_with_where_in_params(self):
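
Note (reviewer sketch, not part of the patch): the refactored `_get_answer` is written without `self` and is called through the class; it builds a `ChatOpenAI` client, forwards `top_p` via `model_kwargs`, and returns a langchain message whose `.content` the non-streaming path of `get_llm_model_answer` unwraps. A minimal usage sketch, assuming `BaseLlmConfig` accepts the keyword arguments shown and that `OPENAI_API_KEY` is set in the environment:

    # Minimal sketch, not part of the diff. Assumes BaseLlmConfig accepts these
    # keyword arguments and that OPENAI_API_KEY is exported in the environment.
    from embedchain.config import BaseLlmConfig
    from embedchain.llm.openai import OpenAILlm

    config = BaseLlmConfig(
        model="gpt-3.5-turbo-0613",
        temperature=0.2,
        max_tokens=100,
        top_p=0.9,      # forwarded to ChatOpenAI through model_kwargs in _get_answer
        stream=False,   # with stream=True, tokens are also echoed to stdout by the callback
    )
    llm = OpenAILlm(config=config)

    # Non-streaming path: _get_answer returns a langchain message, and
    # get_llm_model_answer unwraps .content, so this prints a plain string.
    print(llm.get_llm_model_answer("Say hello in one word."))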