From 0159fcd18a23d0944dcb2e5e632f393efdc97dd7 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 11 Sep 2024 14:23:24 -0700 Subject: [PATCH 1/4] Cerebras Integration --- .github/workflows/contrib-tests.yml | 40 ++ autogen/logger/file_logger.py | 2 + autogen/logger/sqlite_logger.py | 2 + autogen/oai/cerebras.py | 263 +++++++++ autogen/oai/client.py | 12 + autogen/runtime_logging.py | 2 + .../autogen-studio/autogenstudio/datamodel.py | 1 + setup.py | 1 + test/oai/test_cerebras.py | 241 +++++++++ .../non-openai-models/cloud-cerebras.ipynb | 510 ++++++++++++++++++ 10 files changed, 1074 insertions(+) create mode 100644 autogen/oai/cerebras.py create mode 100644 test/oai/test_cerebras.py create mode 100644 website/docs/topics/non-openai-models/cloud-cerebras.ipynb diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index d58098c98e76..c8afaddc8506 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -474,6 +474,46 @@ jobs: file: ./coverage.xml flags: unittests + CerebrasTest: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-2019] + python-version: ["3.9", "3.10", "3.11", "3.12"] + exclude: + - os: macos-latest + python-version: "3.9" + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest-cov>=5 + - name: Install packages and dependencies for Cerebras + run: | + pip install -e .[cerebras_cloud_sdk,test] + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pytest test/oai/test_cerebras.py --skip-openai + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests + MistralTest: runs-on: ${{ matrix.os }} strategy: diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py index 07c9c3b76a76..329510894920 100644 --- a/autogen/logger/file_logger.py +++ b/autogen/logger/file_logger.py @@ -19,6 +19,7 @@ from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient from autogen.oai.bedrock import BedrockClient + from autogen.oai.cerebras import CerebrasClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient @@ -210,6 +211,7 @@ def log_new_client( client: ( AzureOpenAI | OpenAI + | CerebrasClient | GeminiClient | AnthropicClient | MistralAIClient diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index f76d039ce9de..98538f5e22f7 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient + from autogen.oai.cerebras import CerebrasClient from autogen.oai.bedrock import BedrockClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient @@ -397,6 +398,7 @@ def log_new_client( client: Union[ AzureOpenAI, OpenAI, + CerebrasClient, GeminiClient, AnthropicClient, MistralAIClient, diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py new file mode 100644 index 000000000000..92f355583a5e --- /dev/null +++ b/autogen/oai/cerebras.py @@ -0,0 +1,263 @@ +"""Create an OpenAI-compatible client using Cerebras's API. + +Example: + llm_config={ + "config_list": [{ + "api_type": "cerebras", + "model": "llama3.1-8b", + "api_key": os.environ.get("CEREBRAS_API_KEY") + }] + } + + agent = autogen.AssistantAgent("my_agent", llm_config=llm_config) + +Install Cerebras's python library using: pip install --upgrade cerebras_cloud_sdk + +Resources: +- https://inference-docs.cerebras.ai/quickstart +""" + +from __future__ import annotations + +import copy +import os +import time +import warnings +from typing import Any, Dict, List + +from cerebras.cloud.sdk import Cerebras, Stream +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion import ChatCompletionMessage, Choice +from openai.types.completion_usage import CompletionUsage + +from autogen.oai.client_utils import should_hide_tools, validate_parameter + +CEREBRAS_PRICING_1K = { + # Convert pricing per million to per thousand tokens. + "llama3.1-8b": (0.10 / 1000, 0.10 / 1000), + "llama3.1-70b": (0.60 / 1000, 0.60 / 1000), +} + +class CerebrasClient: + """Client for Cerebras's API.""" + + def __init__(self, **kwargs): + """Requires api_key or environment variable to be set + + Args: + api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set) + """ + # Ensure we have the api_key upon instantiation + self.api_key = kwargs.get("api_key", None) + if not self.api_key: + self.api_key = os.getenv("CEREBRAS_API_KEY") + + assert ( + self.api_key + ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." + + def message_retrieval(self, response) -> List: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + return [choice.message for choice in response.choices] + + def cost(self, response) -> float: + return response.cost + + @staticmethod + def get_usage(response) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + # ... # pragma: no cover + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults""" + cerebras_params = {} + + # Check that we have what we need to use Cerebras's API + # We won't enforce the available models as they are likely to change + cerebras_params["model"] = params.get("model", None) + assert cerebras_params[ + "model" + ], "Please specify the 'model' in your config list entry to nominate the Cerebras model to use." + + # Validate allowed Cerebras parameters + # https://inference-docs.cerebras.ai/api-reference/chat-completions + cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None) + cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None) + cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None) + cerebras_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 1.5), None) + cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None) + + return cerebras_params + + def create(self, params: Dict) -> ChatCompletion: + + messages = params.get("messages", []) + + # Convert AutoGen messages to Cerebras messages + cerebras_messages = oai_messages_to_cerebras_messages(messages) + + # Parse parameters to the Cerebras API's parameters + cerebras_params = self.parse_params(params) + + # Add tools to the call if we have them and aren't hiding them + if "tools" in params: + hide_tools = validate_parameter( + params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"] + ) + if not should_hide_tools(cerebras_messages, params["tools"], hide_tools): + cerebras_params["tools"] = params["tools"] + + cerebras_params["messages"] = cerebras_messages + + # We use chat model by default, and set max_retries to 5 (in line with typical retries loop) + client = Cerebras(api_key=self.api_key, max_retries=5) + + # Token counts will be returned + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + # Streaming tool call recommendations + streaming_tool_calls = [] + + ans = None + try: + response = client.chat.completions.create(**cerebras_params) + except Exception as e: + raise RuntimeError(f"Cerebras exception occurred: {e}") + else: + + if cerebras_params["stream"]: + # Read in the chunks as they stream, taking in tool_calls which may be across + # multiple chunks if more than one suggested + ans = "" + for chunk in response: + ans = ans + (chunk.choices[0].delta.content or "") + + if chunk.choices[0].delta.tool_calls: + # We have a tool call recommendation + for tool_call in chunk.choices[0].delta.tool_calls: + streaming_tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call.id, + function={ + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + type="function", + ) + ) + + if chunk.choices[0].finish_reason: + prompt_tokens = chunk.x_cerebras.usage.prompt_tokens + completion_tokens = chunk.x_cerebras.usage.completion_tokens + total_tokens = chunk.x_cerebras.usage.total_tokens + else: + # Non-streaming finished + ans: str = response.choices[0].message.content + + prompt_tokens = response.usage.prompt_tokens + completion_tokens = response.usage.completion_tokens + total_tokens = response.usage.total_tokens + + if response is not None: + if isinstance(response, Stream): + # Streaming response + if chunk.choices[0].finish_reason == "tool_calls": + cerebras_finish = "tool_calls" + tool_calls = streaming_tool_calls + else: + cerebras_finish = "stop" + tool_calls = None + + response_content = ans + response_id = chunk.id + else: + # Non-streaming response + # If we have tool calls as the response, populate completed tool calls for our return OAI response + if response.choices[0].finish_reason == "tool_calls": + cerebras_finish = "tool_calls" + tool_calls = [] + for tool_call in response.choices[0].message.tool_calls: + tool_calls.append( + ChatCompletionMessageToolCall( + id=tool_call.id, + function={"name": tool_call.function.name, "arguments": tool_call.function.arguments}, + type="function", + ) + ) + else: + cerebras_finish = "stop" + tool_calls = None + + response_content = response.choices[0].message.content + response_id = response.id + else: + raise RuntimeError("Failed to get response from Cerebras after retrying 5 times.") + + # 3. convert output + message = ChatCompletionMessage( + role="assistant", + content=response_content, + function_call=None, + tool_calls=tool_calls, + ) + choices = [Choice(finish_reason=cerebras_finish, index=0, message=message)] + + response_oai = ChatCompletion( + id=response_id, + model=cerebras_params["model"], + created=int(time.time()), + object="chat.completion", + choices=choices, + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ), + cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]), + ) + + return response_oai + + +def oai_messages_to_cerebras_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]: + """Convert messages from OAI format to Cerebras's format. + We correct for any specific role orders and types. + """ + + cerebras_messages = copy.deepcopy(messages) + + # Remove the name field + for message in cerebras_messages: + if "name" in message: + message.pop("name", None) + + return cerebras_messages + + +def calculate_cerebras_cost(input_tokens: int, output_tokens: int, model: str) -> float: + """Calculate the cost of the completion using the Cerebras pricing.""" + total = 0.0 + + if model in CEREBRAS_PRICING_1K: + input_cost_per_k, output_cost_per_k = CEREBRAS_PRICING_1K[model] + input_cost = (input_tokens / 1000) * input_cost_per_k + output_cost = (output_tokens / 1000) * output_cost_per_k + total = input_cost + output_cost + else: + warnings.warn(f"Cost calculation not available for model {model}", UserWarning) + + return total diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 4b77815e7eb7..1748b28a7a1f 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -44,6 +44,13 @@ TOOL_ENABLED = True ERROR = None +try: + from autogen.oai.cerebras import CerebrasClient + + cerebras_import_exception: Optional[ImportError] = None +except ImportError as e: + cerebras_import_exception = e + try: from autogen.oai.gemini import GeminiClient @@ -505,6 +512,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s self._configure_azure_openai(config, openai_config) client = AzureOpenAI(**openai_config) self._clients.append(OpenAIClient(client)) + elif api_type is not None and api_type.startswith("cerebras"): + if cerebras_import_exception: + raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.") + client = CerebrasClient(**openai_config) + self._clients.append(client) elif api_type is not None and api_type.startswith("google"): if gemini_import_exception: raise ImportError("Please install `google-generativeai` to use Google OpenAI API.") diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py index 0fd7cc2fc8b9..9036fe5c65cc 100644 --- a/autogen/runtime_logging.py +++ b/autogen/runtime_logging.py @@ -15,6 +15,7 @@ from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient from autogen.oai.bedrock import BedrockClient + from autogen.oai.cerebras import CerebrasClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient @@ -116,6 +117,7 @@ def log_new_client( client: Union[ AzureOpenAI, OpenAI, + CerebrasClient, GeminiClient, AnthropicClient, MistralAIClient, diff --git a/samples/apps/autogen-studio/autogenstudio/datamodel.py b/samples/apps/autogen-studio/autogenstudio/datamodel.py index 92d60cf5c525..ee48818d599f 100644 --- a/samples/apps/autogen-studio/autogenstudio/datamodel.py +++ b/samples/apps/autogen-studio/autogenstudio/datamodel.py @@ -126,6 +126,7 @@ class LLMConfig(SQLModel, table=False): class ModelTypes(str, Enum): openai = "open_ai" + cerebras = "cerebras" google = "google" azure = "azure" anthropic = "anthropic" diff --git a/setup.py b/setup.py index 362aa1217986..63a9c9745ffb 100644 --- a/setup.py +++ b/setup.py @@ -101,6 +101,7 @@ "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor, "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], + "cerebras": ["cerebras_cloud_sdk>=1.0.0"], "mistral": ["mistralai>=1.0.1"], "groq": ["groq>=0.9.0"], "cohere": ["cohere>=5.5.8"], diff --git a/test/oai/test_cerebras.py b/test/oai/test_cerebras.py new file mode 100644 index 000000000000..7374c23f2ebe --- /dev/null +++ b/test/oai/test_cerebras.py @@ -0,0 +1,241 @@ +from unittest.mock import MagicMock, patch + +import pytest + +try: + from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost + + skip = False +except ImportError: + CerebrasClient = object + InternalServerError = object + skip = True + + +# Fixtures for mock data +@pytest.fixture +def mock_response(): + class MockResponse: + def __init__(self, text, choices, usage, cost, model): + self.text = text + self.choices = choices + self.usage = usage + self.cost = cost + self.model = model + + return MockResponse + + +@pytest.fixture +def cerebras_client(): + return CerebrasClient(api_key="fake_api_key") + + +skip_reason = "Cerebras dependency is not installed" + + +# Test initialization and configuration +@pytest.mark.skipif(skip, reason=skip_reason) +def test_initialization(): + + # Missing any api_key + with pytest.raises(AssertionError) as assertinfo: + CerebrasClient() # Should raise an AssertionError due to missing api_key + + assert "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." in str( + assertinfo.value + ) + + # Creation works + CerebrasClient(api_key="fake_api_key") # Should create okay now. + + +# Test standard initialization +@pytest.mark.skipif(skip, reason=skip_reason) +def test_valid_initialization(cerebras_client): + assert cerebras_client.api_key == "fake_api_key", "Config api_key should be correctly set" + + +# Test parameters +@pytest.mark.skipif(skip, reason=skip_reason) +def test_parsing_params(cerebras_client): + # All parameters + params = { + "model": "llama3.1-8b", + "max_tokens": 1000, + "seed": 42, + "stream": False, + "temperature": 1, + "top_p": 0.8, + } + expected_params = { + "model": "llama3.1-8b", + "max_tokens": 1000, + "seed": 42, + "stream": False, + "temperature": 1, + "top_p": 0.8, + } + result = cerebras_client.parse_params(params) + assert result == expected_params + + # Only model, others set as defaults + params = { + "model": "llama3.1-8b", + } + expected_params = { + "model": "llama3.1-8b", + "max_tokens": None, + "seed": None, + "stream": False, + "temperature": 1, + "top_p": None, + } + result = cerebras_client.parse_params(params) + assert result == expected_params + + # Incorrect types, defaults should be set, will show warnings but not trigger assertions + params = { + "model": "llama3.1-8b", + "max_tokens": "1000", + "seed": "42", + "stream": "False", + "temperature": "1", + "top_p": "0.8", + } + result = cerebras_client.parse_params(params) + assert result == expected_params + + # Values outside bounds, should warn and set to defaults + params = { + "model": "llama3.1-8b", + "temperature": 33123, + } + result = cerebras_client.parse_params(params) + assert result == expected_params + + # No model + params = { + "temperature": 1, + } + + with pytest.raises(AssertionError) as assertinfo: + result = cerebras_client.parse_params(params) + + assert "Please specify the 'model' in your config list entry to nominate the Cerebras model to use." in str( + assertinfo.value + ) + + +# Test cost calculation +@pytest.mark.skipif(skip, reason=skip_reason) +def test_cost_calculation(mock_response): + response = mock_response( + text="Example response", + choices=[{"message": "Test message 1"}], + usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800}, + cost=None, + model="llama3.1-70b", + ) + calculated_cost = calculate_cerebras_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model) + + # Convert cost per milliion to cost per token. + expected_cost = response.usage["prompt_tokens"] * 0.6 / 1000000 + response.usage["completion_tokens"] * 0.6 / 1000000 + + assert calculated_cost == expected_cost, f"Cost for this should be ${expected_cost} but got ${calculated_cost}" + + +# Test text generation +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.cerebras.CerebrasClient.create") +def test_create_response(mock_chat, cerebras_client): + # Mock CerebrasClient.chat response + mock_cerebras_response = MagicMock() + mock_cerebras_response.choices = [ + MagicMock(finish_reason="stop", message=MagicMock(content="Example Cerebras response", tool_calls=None)) + ] + mock_cerebras_response.id = "mock_cerebras_response_id" + mock_cerebras_response.model = "llama3.1-70b" + mock_cerebras_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20) # Example token usage + + mock_chat.return_value = mock_cerebras_response + + # Test parameters + params = { + "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}], + "model": "llama3.1-70b", + } + + # Call the create method + response = cerebras_client.create(params) + + # Assertions to check if response is structured as expected + assert ( + response.choices[0].message.content == "Example Cerebras response" + ), "Response content should match expected output" + assert response.id == "mock_cerebras_response_id", "Response ID should match the mocked response ID" + assert response.model == "llama3.1-70b", "Response model should match the mocked response model" + assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage" + assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage" + + +# Test functions/tools +@pytest.mark.skipif(skip, reason=skip_reason) +@patch("autogen.oai.cerebras.CerebrasClient.create") +def test_create_response_with_tool_call(mock_chat, cerebras_client): + # Mock `cerebras_response = client.chat(**cerebras_params)` + mock_function = MagicMock(name="currency_calculator") + mock_function.name = "currency_calculator" + mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}' + + mock_function_2 = MagicMock(name="get_weather") + mock_function_2.name = "get_weather" + mock_function_2.arguments = '{"location": "Chicago"}' + + mock_chat.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="tool_calls", + message=MagicMock( + content="Sample text about the functions", + tool_calls=[ + MagicMock(id="gdRdrvnHh", function=mock_function), + MagicMock(id="abRdrvnHh", function=mock_function_2), + ], + ), + ) + ], + id="mock_cerebras_response_id", + model="llama3.1-70b", + usage=MagicMock(prompt_tokens=10, completion_tokens=20), + ) + + # Construct parameters + converted_functions = [ + { + "type": "function", + "function": { + "description": "Currency exchange calculator.", + "name": "currency_calculator", + "parameters": { + "type": "object", + "properties": { + "base_amount": {"type": "number", "description": "Amount of currency in base_currency"}, + }, + "required": ["base_amount"], + }, + }, + } + ] + cerebras_messages = [ + {"role": "user", "content": "How much is 123.45 EUR in USD?"}, + {"role": "assistant", "content": "World"}, + ] + + # Call the create method + response = cerebras_client.create({"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"}) + + # Assertions to check if the functions and content are included in the response + assert response.choices[0].message.content == "Sample text about the functions" + assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator" + assert response.choices[0].message.tool_calls[1].function.name == "get_weather" diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb new file mode 100644 index 000000000000..d7354671853e --- /dev/null +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -0,0 +1,510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cerebras\n", + "\n", + "[Cerebras](https://cerebras.ai) has developed the world's largest and fastest AI processor, the Wafer-Scale Engine-3 (WSE-3). Notably, the CS-3 system can run large language models like Llama-3.1-8B and Llama-3.1-70B at extremely fast speeds, making it an ideal platform for demanding AI workloads.\n", + "\n", + "While it's technically possible to adapt AutoGen to work with Cerebras' API by updating the `base_url`, this approach may not fully account for minor differences in parameter support. Using this library will also allow for tracking of the API costs based on actual token usage.\n", + "\n", + "For more information about Cerebras Cloud, visit [cloud.cerebras.ai](https://cloud.cerebras.ai). Their API reference is available at [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Requirements\n", + "To use Cerebras with AutoGen, install the `pyautogen[cerebras]` package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install pyautogen[\"cerebras\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started\n", + "\n", + "Cerebras provides a number of models to use, included below. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n", + "\n", + "See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1-8b\",\n", + " \"api_key\": \"your Cerebras API Key goes here\",\n", + " \"api_type\": \"cerebras\"\n", + " },\n", + " {\n", + " \"model\": \"llama3.1-70b\",\n", + " \"api_key\": \"your Cerebras API Key goes here\",\n", + " \"api_type\": \"cerebras\"\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Credentials\n", + "\n", + "Get an API Key from [cloud.cerebras.ai](https://cloud.cerebras.ai/) and add it to your environment variables:\n", + "\n", + "```\n", + "export CEREBRAS_API_KEY=\"your-api-key-here\"\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API parameters\n", + "\n", + "The following parameters can be added to your config for the Cerebras API. See [this link](https://inference-docs.cerebras.ai/api-reference/chat-completions) for further information on them and their default values.\n", + "\n", + "- max_tokens (null, integer >= 0)\n", + "- seed (number)\n", + "- stream (True or False)\n", + "- temperature (number 0..1.5)\n", + "- top_p (number)\n", + "\n", + "Example:\n", + "```python\n", + "[\n", + " {\n", + " \"model\": \"llama3.1-70b\",\n", + " \"api_key\": \"your Cerebras API Key goes here\",\n", + " \"api_type\": \"cerebras\"\n", + " \"max_tokens\": 10000,\n", + " \"seed\": 1234,\n", + " \"stream\" True,\n", + " \"temperature\": 0.5,\n", + " \"top_p\": 0.2, # Note: It is recommended to set temperature or top_p but not both.\n", + " }\n", + "]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two-Agent Coding Example\n", + "\n", + "In this example, we run a two-agent chat with an AssistantAgent (primarily a coding agent) to generate code to count the number of prime numbers between 1 and 10,000 and then it will be executed.\n", + "\n", + "We'll use Meta's LLama-3.1-70B model which is suitable for coding." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost\n", + "\n", + "config_list = [\n", + " {\n", + " \"model\": \"llama3.1-70b\",\n", + " \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n", + " \"api_type\": \"cerebras\"\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword, which we've changed to FINISH, with the code block." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from autogen.coding import LocalCommandLineCodeExecutor\n", + "\n", + "# Setting up the code executor\n", + "workdir = Path(\"coding\")\n", + "workdir.mkdir(exist_ok=True)\n", + "code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n", + "\n", + "# Setting up the agents\n", + "\n", + "# The UserProxyAgent will execute the code that the AssistantAgent provides\n", + "user_proxy_agent = UserProxyAgent(\n", + " name=\"User\",\n", + " code_execution_config={\"executor\": code_executor},\n", + " is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n", + ")\n", + "\n", + "system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n", + "Solve tasks using your coding and language skills.\n", + "In the following cases, suggest python code (in a python coding block) for the user to execute.\n", + "Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n", + "When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n", + "Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n", + "If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n", + "When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n", + "IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n", + "\n", + "# The AssistantAgent, using Cerebras AI's model, will take the coding request and return code\n", + "assistant_agent = AssistantAgent(\n", + " name=\"Cerebras Assistant\",\n", + " system_message=system_message,\n", + " llm_config={\"config_list\": config_list},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mUser\u001b[0m (to Cerebras Assistant):\n", + "\n", + "Provide code to count the number of prime numbers from 1 to 10000.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mCerebras Assistant\u001b[0m (to User):\n", + "\n", + "To count the number of prime numbers from 1 to 10000, we will utilize a simple algorithm that checks each number in the range to see if it is prime. A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n", + "\n", + "Here's how we can do it using a Python script:\n", + "\n", + "```python\n", + "def count_primes(n):\n", + " primes = 0\n", + " for possiblePrime in range(2, n + 1):\n", + " # Assume number is prime until shown it is not. \n", + " isPrime = True\n", + " for num in range(2, int(possiblePrime ** 0.5) + 1):\n", + " if possiblePrime % num == 0:\n", + " isPrime = False\n", + " break\n", + " if isPrime:\n", + " primes += 1\n", + " return primes\n", + "\n", + "# Counting prime numbers from 1 to 10000\n", + "count = count_primes(10000)\n", + "print(count)\n", + "```\n", + "\n", + "Please execute this code. I will respond with \"FINISH\" after you provide the result.\n", + "\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + "Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[31m\n", + ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n" + ] + } + ], + "source": [ + "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n", + "chat_result = user_proxy_agent.initiate_chat(\n", + " assistant_agent,\n", + " message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tool Call Example\n", + "\n", + "In this example, instead of writing code, we will show how Meta's Llama-3.1-70B model can perform parallel tool calling, where it recommends calling more than one tool at a time.\n", + "\n", + "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n", + "\n", + "We start by importing libraries and setting up our configuration to use Llama-3.1-8B and the `cerebras` client class." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from typing import Literal\n", + "\n", + "from typing_extensions import Annotated\n", + "\n", + "import autogen\n", + "\n", + "config_list = [\n", + " {\n", + " \"model\": \"llama3.1-70b\",\n", + " \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n", + " \"api_type\": \"cerebras\",\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create our two agents." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the agent for tool calling\n", + "chatbot = autogen.AssistantAgent(\n", + " name=\"chatbot\",\n", + " system_message=\"\"\"\n", + " For currency exchange and weather forecasting tasks,\n", + " only use the functions you have been provided with.\n", + " When you summarize, make sure you've considered ALL previous instructions.\n", + " Output 'HAVE FUN!' when an answer has been provided.\n", + " \"\"\",\n", + " llm_config={\"config_list\": config_list},\n", + ")\n", + "\n", + "# Note that we have changed the termination string to be \"HAVE FUN!\"\n", + "user_proxy = autogen.UserProxyAgent(\n", + " name=\"user_proxy\",\n", + " is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n", + " human_input_mode=\"NEVER\",\n", + " max_consecutive_auto_reply=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n", + "\n", + "We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "# Currency Exchange function\n", + "\n", + "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n", + "\n", + "# Define our function that we expect to call\n", + "\n", + "\n", + "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n", + " if base_currency == quote_currency:\n", + " return 1.0\n", + " elif base_currency == \"USD\" and quote_currency == \"EUR\":\n", + " return 1 / 1.1\n", + " elif base_currency == \"EUR\" and quote_currency == \"USD\":\n", + " return 1.1\n", + " else:\n", + " raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n", + "def currency_calculator(\n", + " base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n", + " base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n", + " quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n", + ") -> str:\n", + " quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n", + " return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n", + "\n", + "\n", + "# Weather function\n", + "\n", + "\n", + "# Example function to make available to model\n", + "def get_current_weather(location, unit=\"fahrenheit\"):\n", + " \"\"\"Get the weather for some location\"\"\"\n", + " if \"chicago\" in location.lower():\n", + " return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n", + " elif \"new york\" in location.lower():\n", + " return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", + "\n", + "\n", + "# Register the function with the agent\n", + "\n", + "\n", + "@user_proxy.register_for_execution()\n", + "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n", + "def weather_forecast(\n", + " location: Annotated[str, \"City name\"],\n", + ") -> str:\n", + " weather_details = get_current_weather(location=location)\n", + " weather = json.loads(weather_details)\n", + " return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We pass through our customer's message and run the chat.\n", + "\n", + "Finally, we ask the LLM to summarise the chat and print that out." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\u001b[32m***** Suggested tool call (210f6ac6d): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\": \"New York\"}\n", + "\u001b[32m*************************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (3c00ac7d5): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n", + "\u001b[32m****************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (210f6ac6d) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (3c00ac7d5) *****\u001b[0m\n", + "135.80 USD\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "New York will be 11 degrees fahrenheit.\n", + "123.45 EUR is equivalent to 135.80 USD.\n", + " \n", + "For a great holiday, explore the Statue of Liberty, take a walk through Central Park, or visit one of the many world-class museums. Also, you'll find great food ranging from bagels to fine dining experiences. HAVE FUN!\n", + "\n", + "--------------------------------------------------------------------------------\n", + "LLM SUMMARY: New York will be 11 degrees fahrenheit. 123.45 EUR is equivalent to 135.80 USD. Explore the Statue of Liberty, walk through Central Park, or visit one of the many world-class museums for a great holiday in New York.\n", + "\n", + "Duration: 73.97937774658203ms\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "start_time = time.time()\n", + "\n", + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\",\n", + " summary_method=\"reflection_with_llm\",\n", + ")\n", + "\n", + "end_time = time.time()\n", + "\n", + "print(f\"LLM SUMMARY: {res.summary['content']}\\n\\nDuration: {(end_time - start_time) * 1000}ms\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see that the Cerebras Wafer-Scale Engine-3 (WSE-3) completed the query in 74ms -- faster than the blink of an eye!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From bb95583c6b7f01f253250f292ba860db7a406b2f Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 26 Sep 2024 13:56:47 -0700 Subject: [PATCH 2/4] Address feedback --- autogen/oai/cerebras.py | 15 ++++++++++----- .../topics/non-openai-models/cloud-cerebras.ipynb | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py index 92f355583a5e..c5e153c54657 100644 --- a/autogen/oai/cerebras.py +++ b/autogen/oai/cerebras.py @@ -38,17 +38,18 @@ "llama3.1-70b": (0.60 / 1000, 0.60 / 1000), } + class CerebrasClient: """Client for Cerebras's API.""" - def __init__(self, **kwargs): + def __init__(self, api_key = None, **kwargs): """Requires api_key or environment variable to be set Args: api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set) """ # Ensure we have the api_key upon instantiation - self.api_key = kwargs.get("api_key", None) + self.api_key = api_key if not self.api_key: self.api_key = os.getenv("CEREBRAS_API_KEY") @@ -56,7 +57,7 @@ def __init__(self, **kwargs): self.api_key ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." - def message_retrieval(self, response) -> List: + def message_retrieval(self, response: ChatCompletion) -> List: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -65,11 +66,12 @@ def message_retrieval(self, response) -> List: """ return [choice.message for choice in response.choices] - def cost(self, response) -> float: + def cost(self, response: ChatCompletion) -> float: + # Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation. return response.cost @staticmethod - def get_usage(response) -> Dict: + def get_usage(response: ChatCompletion) -> Dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" # ... # pragma: no cover return { @@ -144,6 +146,7 @@ def create(self, params: Dict) -> ChatCompletion: # multiple chunks if more than one suggested ans = "" for chunk in response: + # Grab first choice, which _should_ always be generated. ans = ans + (chunk.choices[0].delta.content or "") if chunk.choices[0].delta.tool_calls: @@ -227,6 +230,8 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens=completion_tokens, total_tokens=total_tokens, ), + # Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic + # just adds it dynamically. cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]), ) diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb index d7354671853e..73c3cdafa4cb 100644 --- a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -36,7 +36,7 @@ "source": [ "## Getting Started\n", "\n", - "Cerebras provides a number of models to use, included below. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n", + "Cerebras provides a number of models to use. See the list of [models here](https://inference-docs.cerebras.ai/introduction).\n", "\n", "See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n", "```python\n", @@ -223,7 +223,7 @@ ] }, { - "name": "stdin", + "name": "stdout", "output_type": "stream", "text": [ "Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n" From 6bc31adf83032a74161d972e0a339b60c088520c Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Fri, 27 Sep 2024 15:48:59 -0400 Subject: [PATCH 3/4] Fix typo --- website/docs/topics/non-openai-models/cloud-cerebras.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb index 73c3cdafa4cb..b264e0cfcc47 100644 --- a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -256,7 +256,7 @@ "\n", "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n", "\n", - "We start by importing libraries and setting up our configuration to use Llama-3.1-8B and the `cerebras` client class." + "We start by importing libraries and setting up our configuration to use Llama-3.1-70B and the `cerebras` client class." ] }, { From 5000da886764f196ba9d0d3f979f01bc02d4fe50 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Mon, 30 Sep 2024 12:56:49 -0700 Subject: [PATCH 4/4] Run formatter --- autogen/logger/sqlite_logger.py | 2 +- autogen/oai/cerebras.py | 6 ++++-- test/oai/test_cerebras.py | 19 +++++++++++++------ .../non-openai-models/cloud-cerebras.ipynb | 9 ++------- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py index 98538f5e22f7..3849c19711c7 100644 --- a/autogen/logger/sqlite_logger.py +++ b/autogen/logger/sqlite_logger.py @@ -19,8 +19,8 @@ if TYPE_CHECKING: from autogen import Agent, ConversableAgent, OpenAIWrapper from autogen.oai.anthropic import AnthropicClient - from autogen.oai.cerebras import CerebrasClient from autogen.oai.bedrock import BedrockClient + from autogen.oai.cerebras import CerebrasClient from autogen.oai.cohere import CohereClient from autogen.oai.gemini import GeminiClient from autogen.oai.groq import GroqClient diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py index c5e153c54657..e87b048e1366 100644 --- a/autogen/oai/cerebras.py +++ b/autogen/oai/cerebras.py @@ -42,7 +42,7 @@ class CerebrasClient: """Client for Cerebras's API.""" - def __init__(self, api_key = None, **kwargs): + def __init__(self, api_key=None, **kwargs): """Requires api_key or environment variable to be set Args: @@ -98,7 +98,9 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None) cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None) cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None) - cerebras_params["temperature"] = validate_parameter(params, "temperature", (int, float), True, 1, (0, 1.5), None) + cerebras_params["temperature"] = validate_parameter( + params, "temperature", (int, float), True, 1, (0, 1.5), None + ) cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None) return cerebras_params diff --git a/test/oai/test_cerebras.py b/test/oai/test_cerebras.py index 7374c23f2ebe..7f84ae3f9d56 100644 --- a/test/oai/test_cerebras.py +++ b/test/oai/test_cerebras.py @@ -42,8 +42,9 @@ def test_initialization(): with pytest.raises(AssertionError) as assertinfo: CerebrasClient() # Should raise an AssertionError due to missing api_key - assert "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." in str( - assertinfo.value + assert ( + "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable." + in str(assertinfo.value) ) # Creation works @@ -137,11 +138,15 @@ def test_cost_calculation(mock_response): cost=None, model="llama3.1-70b", ) - calculated_cost = calculate_cerebras_cost(response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model) + calculated_cost = calculate_cerebras_cost( + response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model + ) # Convert cost per milliion to cost per token. - expected_cost = response.usage["prompt_tokens"] * 0.6 / 1000000 + response.usage["completion_tokens"] * 0.6 / 1000000 - + expected_cost = ( + response.usage["prompt_tokens"] * 0.6 / 1000000 + response.usage["completion_tokens"] * 0.6 / 1000000 + ) + assert calculated_cost == expected_cost, f"Cost for this should be ${expected_cost} but got ${calculated_cost}" @@ -233,7 +238,9 @@ def test_create_response_with_tool_call(mock_chat, cerebras_client): ] # Call the create method - response = cerebras_client.create({"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"}) + response = cerebras_client.create( + {"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"} + ) # Assertions to check if the functions and content are included in the response assert response.choices[0].message.content == "Sample text about the functions" diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb index b264e0cfcc47..a8e1d3940f4b 100644 --- a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -117,15 +117,10 @@ "outputs": [], "source": [ "import os\n", + "\n", "from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost\n", "\n", - "config_list = [\n", - " {\n", - " \"model\": \"llama3.1-70b\",\n", - " \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n", - " \"api_type\": \"cerebras\"\n", - " }\n", - "]" + "config_list = [{\"model\": \"llama3.1-70b\", \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"), \"api_type\": \"cerebras\"}]" ] }, {