Skip to content

Commit

Permalink
feat(openai): add azure support (#96)
Browse files Browse the repository at this point in the history
Adds support for using Azure OpenAI endpoints
  • Loading branch information
trojan-bumble-bee authored Nov 19, 2024
1 parent 97fd19b commit d61cbeb
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ OPENAI_API_KEY=$TOGETHER_API_KEY OPENAI_BASE_URL=https://api.together.xyz/v1 gpt

The prefix is stripped before sending the request to the API.

Similarly, use the `oai-azure:` model name prefix to use a model deployed via Azure Open AI. For example, `oai-azure:my-deployment-name`.

## Other chat bots

### Anthropic Claude
Expand Down
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gptcli.providers.openai import OpenAICompletionProvider
from gptcli.providers.anthropic import AnthropicCompletionProvider
from gptcli.providers.cohere import CohereCompletionProvider
from gptcli.providers.azure_openai import AzureOpenAICompletionProvider


class AssistantConfig(TypedDict, total=False):
Expand Down Expand Up @@ -74,6 +75,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
or model.startswith("o1")
):
return OpenAICompletionProvider()
elif model.startswith("oai-azure:"):
return AzureOpenAICompletionProvider()
elif model.startswith("claude"):
return AnthropicCompletionProvider()
elif model.startswith("llama"):
Expand Down
1 change: 1 addition & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class GptCliConfig:
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL")
openai_azure_api_version: str = "2024-10-21"
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
Expand Down
3 changes: 3 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def main():
if config.openai_base_url:
openai.base_url = config.openai_base_url

if config.openai_azure_api_version:
openai.api_version = config.openai_azure_api_version

if config.api_key:
openai.api_key = config.api_key
elif config.openai_api_key:
Expand Down
13 changes: 13 additions & 0 deletions gptcli/providers/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import openai
from openai import AzureOpenAI
from gptcli.providers.openai import OpenAICompletionProvider


class AzureOpenAICompletionProvider(OpenAICompletionProvider):
def __init__(self):
super().__init__()
self.client = AzureOpenAI(
api_key=openai.api_key,
base_url=openai.base_url,
api_version=openai.api_version,
)
5 changes: 4 additions & 1 deletion gptcli/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Iterator, List, Optional, cast
import openai
from openai import OpenAI
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessageParam

from gptcli.completion import (
Expand Down Expand Up @@ -33,6 +33,9 @@ def complete(
if model.startswith("oai-compat:"):
model = model[len("oai-compat:") :]

if model.startswith("oai-azure:"):
model = model[len("oai-azure:") :]

try:
if stream:
response_iter = self.client.chat.completions.create(
Expand Down

0 comments on commit d61cbeb

Please sign in to comment.