diff --git a/README.md b/README.md index 55d1316..63e8bb5 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,8 @@ OPENAI_API_KEY=$TOGETHER_API_KEY OPENAI_BASE_URL=https://api.together.xyz/v1 gpt The prefix is stripped before sending the request to the API. +Similarly, use the `oai-azure:` model name prefix to use a model deployed via Azure Open AI. For example, `oai-azure:my-deployment-name`. + ## Other chat bots ### Anthropic Claude diff --git a/gptcli/assistant.py b/gptcli/assistant.py index d6932a0..82d44d1 100644 --- a/gptcli/assistant.py +++ b/gptcli/assistant.py @@ -15,6 +15,7 @@ from gptcli.providers.openai import OpenAICompletionProvider from gptcli.providers.anthropic import AnthropicCompletionProvider from gptcli.providers.cohere import CohereCompletionProvider +from gptcli.providers.azure_openai import AzureOpenAICompletionProvider class AssistantConfig(TypedDict, total=False): @@ -75,6 +76,8 @@ def get_completion_provider(model: str) -> CompletionProvider: or model.startswith("o1") ): return OpenAICompletionProvider() + elif model.startswith("oai-azure:"): + return AzureOpenAICompletionProvider() elif model.startswith("claude"): return AnthropicCompletionProvider() elif model.startswith("llama"): diff --git a/gptcli/config.py b/gptcli/config.py index a23753b..ce23e07 100644 --- a/gptcli/config.py +++ b/gptcli/config.py @@ -21,6 +21,7 @@ class GptCliConfig: api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL") + openai_azure_api_version: str = "2024-10-21" anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY") google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY") cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY") diff --git a/gptcli/gpt.py b/gptcli/gpt.py index 897694e..a5b0260 100755 --- a/gptcli/gpt.py +++ b/gptcli/gpt.py @@ -179,6 +179,9 @@ def main(): if config.openai_base_url: openai.base_url = config.openai_base_url + if config.openai_azure_api_version: + openai.api_version = config.openai_azure_api_version + if config.api_key: openai.api_key = config.api_key elif config.openai_api_key: diff --git a/gptcli/providers/azure_openai.py b/gptcli/providers/azure_openai.py new file mode 100644 index 0000000..12648ac --- /dev/null +++ b/gptcli/providers/azure_openai.py @@ -0,0 +1,13 @@ +import openai +from openai import AzureOpenAI +from gptcli.providers.openai import OpenAICompletionProvider + + +class AzureOpenAICompletionProvider(OpenAICompletionProvider): + def __init__(self): + super().__init__() + self.client = AzureOpenAI( + api_key=openai.api_key, + base_url=openai.base_url, + api_version=openai.api_version, + ) diff --git a/gptcli/providers/openai.py b/gptcli/providers/openai.py index 97574fa..14a8035 100644 --- a/gptcli/providers/openai.py +++ b/gptcli/providers/openai.py @@ -1,7 +1,7 @@ import re from typing import Iterator, List, Optional, cast import openai -from openai import OpenAI +from openai import AzureOpenAI, OpenAI from openai.types.chat import ChatCompletionMessageParam from gptcli.completion import ( @@ -33,6 +33,9 @@ def complete( if model.startswith("oai-compat:"): model = model[len("oai-compat:") :] + if model.startswith("oai-azure:"): + model = model[len("oai-azure:") :] + try: if stream: response_iter = self.client.chat.completions.create(