# Python: Support DefaultAzureCredential for AzureAI Inference (#8862)
### Motivation and Context

In recent work, the `DefaultAzureCredential` authentication path was added for Azure OpenAI, but it was missed for Azure AI Inference.


### Description

Add support for `DefaultAzureCredential` in the Azure AI Inference chat completion and text embedding connectors, along with test updates. A brief usage sketch of the keyless path is shown below.

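With this change, leaving the API key unset makes the connector fall back to `DefaultAzureCredential`. As a minimal sketch (not part of this commit; the `endpoint` keyword and the placeholder values are assumptions), keyless chat completion setup could look like this:

```python
# Minimal keyless sketch (assumption: AzureAIInferenceChatCompletion is exported from
# this package and accepts ai_model_id/endpoint keywords). Authentication comes from
# the ambient identity, e.g. `az login` or a managed identity with access to the resource.
from semantic_kernel.connectors.ai.azure_ai_inference import AzureAIInferenceChatCompletion

chat_service = AzureAIInferenceChatCompletion(
    ai_model_id="my-model-deployment",  # placeholder deployment/model name
    # api_key intentionally omitted: the connector now builds its client with
    # DefaultAzureCredential and the cognitiveservices scope instead.
    endpoint="https://my-resource.services.ai.azure.com/models",  # placeholder endpoint
)
```

Requests made through this service then authenticate with whatever credential `DefaultAzureCredential` resolves (environment variables, managed identity, Azure CLI, and so on).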

### Contribution Checklist


- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
moonbox3 authored Sep 17, 2024
1 parent 35f9c52 commit 40e4c1c
Showing 4 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-integration-tests.yml
@@ -241,7 +241,7 @@ jobs:
         if: matrix.os == 'ubuntu-latest'
         run: docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest
       - name: Azure CLI Login
-        if: github.event_name != 'pull_request' && matrix.integration-tests
+        if: github.event_name != 'pull_request'
         uses: azure/login@v2
         with:
           client-id: ${{ secrets.AZURE_CLIENT_ID }}
@@ -21,6 +21,7 @@
     StreamingChatCompletionsUpdate,
 )
 from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential
 from pydantic import ValidationError

 from semantic_kernel.connectors.ai.azure_ai_inference import (
@@ -34,6 +35,7 @@
 from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration
 from semantic_kernel.connectors.ai.function_calling_utils import update_settings_from_function_call_configuration
 from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType
+from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -106,11 +108,22 @@ def __init__(
         except ValidationError as e:
             raise ServiceInitializationError(f"Failed to validate Azure AI Inference settings: {e}") from e

-        client = ChatCompletionsClient(
-            endpoint=str(azure_ai_inference_settings.endpoint),
-            credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
-            user_agent=SEMANTIC_KERNEL_USER_AGENT,
-        )
+        endpoint_to_use: str = str(azure_ai_inference_settings.endpoint)
+        if azure_ai_inference_settings.api_key is not None:
+            client = ChatCompletionsClient(
+                endpoint=endpoint_to_use,
+                credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
+                user_agent=SEMANTIC_KERNEL_USER_AGENT,
+            )
+        else:
+            # Try to create the client with a DefaultAzureCredential
+            client = ChatCompletionsClient(
+                endpoint=endpoint_to_use,
+                credential=DefaultAzureCredential(),
+                credential_scopes=["https://cognitiveservices.azure.com/.default"],
+                api_version=DEFAULT_AZURE_API_VERSION,
+                user_agent=SEMANTIC_KERNEL_USER_AGENT,
+            )

         super().__init__(
             ai_model_id=ai_model_id,
@@ -14,12 +14,15 @@
 else:
     from typing_extensions import override  # pragma: no cover

+from azure.identity import DefaultAzureCredential
+
 from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_prompt_execution_settings import (
     AzureAIInferenceEmbeddingPromptExecutionSettings,
 )
 from semantic_kernel.connectors.ai.azure_ai_inference.azure_ai_inference_settings import AzureAIInferenceSettings
 from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_base import AzureAIInferenceBase
 from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
+from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
 from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
 from semantic_kernel.utils.experimental_decorator import experimental_class
 from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT
@@ -72,11 +75,22 @@ def __init__(
         except ValidationError as e:
             raise ServiceInitializationError(f"Failed to validate Azure AI Inference settings: {e}") from e

-        client = EmbeddingsClient(
-            endpoint=str(azure_ai_inference_settings.endpoint),
-            credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
-            user_agent=SEMANTIC_KERNEL_USER_AGENT,
-        )
+        endpoint = str(azure_ai_inference_settings.endpoint)
+        if azure_ai_inference_settings.api_key is not None:
+            client = EmbeddingsClient(
+                endpoint=endpoint,
+                credential=AzureKeyCredential(azure_ai_inference_settings.api_key.get_secret_value()),
+                user_agent=SEMANTIC_KERNEL_USER_AGENT,
+            )
+        else:
+            # Try to create the client with a DefaultAzureCredential
+            client = EmbeddingsClient(
+                endpoint=endpoint,
+                credential=DefaultAzureCredential(),
+                credential_scopes=["https://cognitiveservices.azure.com/.default"],
+                api_version=DEFAULT_AZURE_API_VERSION,
+                user_agent=SEMANTIC_KERNEL_USER_AGENT,
+            )

         super().__init__(
             ai_model_id=ai_model_id,
@@ -3,11 +3,12 @@

 import pytest
 from azure.ai.inference.aio import EmbeddingsClient
-from azure.core.credentials import AzureKeyCredential
+from azure.identity import DefaultAzureCredential

 from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_text_embedding import (
     AzureAIInferenceTextEmbedding,
 )
+from semantic_kernel.connectors.ai.open_ai.const import DEFAULT_AZURE_API_VERSION
 from semantic_kernel.connectors.ai.open_ai.settings.azure_open_ai_settings import AzureOpenAISettings
 from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin
 from semantic_kernel.kernel import Kernel
@@ -21,14 +22,14 @@ async def test_azure_ai_inference_embedding_service(kernel: Kernel):
     azure_openai_settings = AzureOpenAISettings.create()
     endpoint = azure_openai_settings.endpoint
     deployment_name = azure_openai_settings.embedding_deployment_name
-    api_key = azure_openai_settings.api_key.get_secret_value()

     embeddings_gen = AzureAIInferenceTextEmbedding(
         ai_model_id=deployment_name,
         client=EmbeddingsClient(
             endpoint=f'{str(endpoint).strip("/")}/openai/deployments/{deployment_name}',
-            credential=AzureKeyCredential(""),
-            headers={"api-key": api_key},
+            credential=DefaultAzureCredential(),
+            credential_scopes=["https://cognitiveservices.azure.com/.default"],
+            api_version=DEFAULT_AZURE_API_VERSION,
         ),
     )
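For the embedding side, the integration test above builds the `EmbeddingsClient` explicitly. A hedged end-to-end sketch that instead relies on the connector's new keyless fallback (the deployment name is a placeholder, and the endpoint is assumed to be configured through the connector's settings/environment) might look like this:

```python
# Hedged sketch (not from this commit): keyless embedding generation through the
# connector's DefaultAzureCredential fallback. Assumes the endpoint is configured via
# the connector's settings/environment and that the ambient identity has access.
import asyncio

from semantic_kernel.connectors.ai.azure_ai_inference.services.azure_ai_inference_text_embedding import (
    AzureAIInferenceTextEmbedding,
)


async def main() -> None:
    embedding_service = AzureAIInferenceTextEmbedding(ai_model_id="my-embedding-deployment")  # placeholder
    vectors = await embedding_service.generate_embeddings(["hello world"])  # one vector per input text
    print(len(vectors[0]))


asyncio.run(main())
```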
