From 5a0286967c819f27246fb4add3497fb661c932f6 Mon Sep 17 00:00:00 2001 From: Valery Kharitonov Date: Tue, 28 May 2024 22:20:05 -0400 Subject: [PATCH] Move provider definitions to separate package --- gptcli/assistant.py | 8 ++++---- gptcli/config.py | 2 +- gptcli/gpt.py | 10 +++++----- gptcli/providers/__init__.py | 0 gptcli/{ => providers}/anthropic.py | 0 gptcli/{ => providers}/cohere.py | 0 gptcli/{ => providers}/llama.py | 0 gptcli/{ => providers}/openai.py | 0 gptcli/shell.py | 7 +++++-- 9 files changed, 15 insertions(+), 12 deletions(-) create mode 100644 gptcli/providers/__init__.py rename gptcli/{ => providers}/anthropic.py (100%) rename gptcli/{ => providers}/cohere.py (100%) rename gptcli/{ => providers}/llama.py (100%) rename gptcli/{ => providers}/openai.py (100%) diff --git a/gptcli/assistant.py b/gptcli/assistant.py index 32ee895..47fdc02 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -3,7 +3,6 @@ from attr import dataclass import platform from typing import Any, Dict, Iterator, Optional, TypedDict, List -from gptcli.cohere import CohereCompletionProvider from gptcli.completion import ( CompletionEvent, @@ -11,9 +10,10 @@ ModelOverrides, Message, ) -from gptcli.llama import LLaMACompletionProvider -from gptcli.openai import OpenAICompletionProvider -from gptcli.anthropic import AnthropicCompletionProvider +from gptcli.providers.llama import LLaMACompletionProvider +from gptcli.providers.openai import OpenAICompletionProvider +from gptcli.providers.anthropic import AnthropicCompletionProvider +from gptcli.providers.cohere import CohereCompletionProvider class AssistantConfig(TypedDict, total=False): diff --git a/gptcli/config.py b/gptcli/config.py index 6111128..b6c83cd 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -4,7 +4,7 @@ import yaml from gptcli.assistant import AssistantConfig -from gptcli.llama import LLaMAModelConfig +from gptcli.providers.llama import LLaMAModelConfig CONFIG_FILE_PATHS = [ diff --git a/gptcli/gpt.py b/gptcli/gpt.py index 06045f6..232d9d5 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -13,8 +13,8 @@ import sys import logging import datetime -import gptcli.anthropic -import gptcli.cohere +import gptcli.providers.anthropic +import gptcli.providers.cohere from gptcli.assistant import ( Assistant, DEFAULT_ASSISTANTS, @@ -32,7 +32,7 @@ choose_config_file, read_yaml_config, ) -from gptcli.llama import init_llama_models +from gptcli.providers.llama import init_llama_models from gptcli.logging import LoggingChatListener from gptcli.cost import PriceChatListener from gptcli.session import ChatSession @@ -184,10 +184,10 @@ def main(): openai.api_key = config.openai_api_key if config.anthropic_api_key: - gptcli.anthropic.api_key = config.anthropic_api_key + gptcli.providers.anthropic.api_key = config.anthropic_api_key if config.cohere_api_key: - gptcli.cohere.api_key = config.cohere_api_key + gptcli.providers.cohere.api_key = config.cohere_api_key if config.llama_models is not None: init_llama_models(config.llama_models) diff --git a/gptcli/providers/__init__.py b/gptcli/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gptcli/anthropic.py b/gptcli/providers/anthropic.py similarity index 100% rename from gptcli/anthropic.py rename to gptcli/providers/anthropic.py diff --git a/gptcli/cohere.py b/gptcli/providers/cohere.py similarity index 100% rename from gptcli/cohere.py rename to gptcli/providers/cohere.py diff --git a/gptcli/llama.py b/gptcli/providers/llama.py similarity index 100% rename from gptcli/llama.py rename to gptcli/providers/llama.py diff --git a/gptcli/openai.py b/gptcli/providers/openai.py similarity index 100% rename from gptcli/openai.py rename to gptcli/providers/openai.py diff --git a/gptcli/shell.py b/gptcli/shell.py index 83bc681..4005714 100644 --- a/gptcli/shell.py +++ b/gptcli/shell.py @@ -14,8 +14,9 @@ def simple_response(assistant: Assistant, prompt: str, stream: bool) -> None: result = "" try: for response in response_iter: - result += response - sys.stdout.write(response) + if response.type == "message_delta": + result += response.text + sys.stdout.write(response.text) except KeyboardInterrupt: pass finally: @@ -29,6 +30,8 @@ def execute(assistant: Assistant, prompt: str) -> None: logging.info("User: %s", prompt) response_iter = assistant.complete_chat(messages, stream=False) result = next(response_iter) + assert result.type == "message_delta" + result = result.text logging.info("Assistant: %s", result) with tempfile.NamedTemporaryFile(mode="w", prefix="gptcli-", delete=False) as f: