diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index d58098c98e76..c8afaddc8506 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -474,6 +474,46 @@ jobs:
           file: ./coverage.xml
           flags: unittests
 
+  CerebrasTest:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-2019]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        exclude:
+          - os: macos-latest
+            python-version: "3.9"
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          lfs: true
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install pytest-cov>=5
+      - name: Install packages and dependencies for Cerebras
+        run: |
+          pip install -e .[cerebras,test]
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+          fi
+      - name: Coverage
+        run: |
+          pytest test/oai/test_cerebras.py --skip-openai
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v3
+        with:
+          file: ./coverage.xml
+          flags: unittests
+
   MistralTest:
     runs-on: ${{ matrix.os }}
     strategy:
diff --git a/autogen/logger/file_logger.py b/autogen/logger/file_logger.py
index 07c9c3b76a76..329510894920 100644
--- a/autogen/logger/file_logger.py
+++ b/autogen/logger/file_logger.py
@@ -19,6 +19,7 @@
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.bedrock import BedrockClient
+    from autogen.oai.cerebras import CerebrasClient
     from autogen.oai.cohere import CohereClient
     from autogen.oai.gemini import GeminiClient
     from autogen.oai.groq import GroqClient
@@ -210,6 +211,7 @@ def log_new_client(
         client: (
             AzureOpenAI
             | OpenAI
+            | CerebrasClient
             | GeminiClient
             | AnthropicClient
             | MistralAIClient
diff --git a/autogen/logger/sqlite_logger.py b/autogen/logger/sqlite_logger.py
index f76d039ce9de..3849c19711c7 100644
--- a/autogen/logger/sqlite_logger.py
+++ b/autogen/logger/sqlite_logger.py
@@ -20,6 +20,7 @@
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.bedrock import BedrockClient
+    from autogen.oai.cerebras import CerebrasClient
     from autogen.oai.cohere import CohereClient
     from autogen.oai.gemini import GeminiClient
     from autogen.oai.groq import GroqClient
@@ -397,6 +398,7 @@ def log_new_client(
         client: Union[
             AzureOpenAI,
             OpenAI,
+            CerebrasClient,
             GeminiClient,
             AnthropicClient,
             MistralAIClient,
diff --git a/autogen/oai/cerebras.py b/autogen/oai/cerebras.py
new file mode 100644
index 000000000000..e87b048e1366
--- /dev/null
+++ b/autogen/oai/cerebras.py
@@ -0,0 +1,270 @@
+"""Create an OpenAI-compatible client using Cerebras's API.
+
+Example:
+    llm_config={
+        "config_list": [{
+            "api_type": "cerebras",
+            "model": "llama3.1-8b",
+            "api_key": os.environ.get("CEREBRAS_API_KEY")
+        }]
+    }
+
+    agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
+
+Install Cerebras's python library using: pip install --upgrade cerebras_cloud_sdk
+
+Resources:
+- https://inference-docs.cerebras.ai/quickstart
+"""
+
+from __future__ import annotations
+
+import copy
+import os
+import time
+import warnings
+from typing import Any, Dict, List
+
+from cerebras.cloud.sdk import Cerebras, Stream
+from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+from openai.types.completion_usage import CompletionUsage
+
+from autogen.oai.client_utils import should_hide_tools, validate_parameter
+
+CEREBRAS_PRICING_1K = {
+    # Convert pricing per million to per thousand tokens.
+    "llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
+    "llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
+}
+
+
+class CerebrasClient:
+    """Client for Cerebras's API."""
+
+    def __init__(self, api_key=None, **kwargs):
+        """Requires api_key or environment variable to be set
+
+        Args:
+            api_key (str): The API key for using Cerebras (or environment variable CEREBRAS_API_KEY needs to be set)
+        """
+        # Ensure we have the api_key upon instantiation
+        self.api_key = api_key
+        if not self.api_key:
+            self.api_key = os.getenv("CEREBRAS_API_KEY")
+
+        assert (
+            self.api_key
+        ), "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
+
+    def message_retrieval(self, response: ChatCompletion) -> List:
+        """
+        Retrieve and return a list of strings or a list of Choice.Message from the response.
+
+        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
+        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
+        """
+        return [choice.message for choice in response.choices]
+
+    def cost(self, response: ChatCompletion) -> float:
+        # Note: This field isn't explicitly in `ChatCompletion`, but is injected during chat creation.
+        return response.cost
+
+    @staticmethod
+    def get_usage(response: ChatCompletion) -> Dict:
+        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
+        # ...  # pragma: no cover
+        return {
+            "prompt_tokens": response.usage.prompt_tokens,
+            "completion_tokens": response.usage.completion_tokens,
+            "total_tokens": response.usage.total_tokens,
+            "cost": response.cost,
+            "model": response.model,
+        }
+
+    def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """Loads the parameters for Cerebras API from the passed in parameters and returns a validated set. Checks types, ranges, and sets defaults."""
+        cerebras_params = {}
+
+        # Check that we have what we need to use Cerebras's API
+        # We won't enforce the available models as they are likely to change
+        cerebras_params["model"] = params.get("model", None)
+        assert cerebras_params[
+            "model"
+        ], "Please specify the 'model' in your config list entry to nominate the Cerebras model to use."
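+
+        # Note: validate_parameter (from autogen.oai.client_utils) takes, in order:
+        #   params, param_name, allowed_types, allow_None, default_value, numerical_bound, allowed_values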
+
+        # Validate allowed Cerebras parameters
+        # https://inference-docs.cerebras.ai/api-reference/chat-completions
+        cerebras_params["max_tokens"] = validate_parameter(params, "max_tokens", int, True, None, (0, None), None)
+        cerebras_params["seed"] = validate_parameter(params, "seed", int, True, None, None, None)
+        cerebras_params["stream"] = validate_parameter(params, "stream", bool, True, False, None, None)
+        cerebras_params["temperature"] = validate_parameter(
+            params, "temperature", (int, float), True, 1, (0, 1.5), None
+        )
+        cerebras_params["top_p"] = validate_parameter(params, "top_p", (int, float), True, None, None, None)
+
+        return cerebras_params
+
+    def create(self, params: Dict) -> ChatCompletion:
+
+        messages = params.get("messages", [])
+
+        # Convert AutoGen messages to Cerebras messages
+        cerebras_messages = oai_messages_to_cerebras_messages(messages)
+
+        # Parse parameters to the Cerebras API's parameters
+        cerebras_params = self.parse_params(params)
+
+        # Add tools to the call if we have them and aren't hiding them
+        if "tools" in params:
+            hide_tools = validate_parameter(
+                params, "hide_tools", str, False, "never", None, ["if_all_run", "if_any_run", "never"]
+            )
+            if not should_hide_tools(cerebras_messages, params["tools"], hide_tools):
+                cerebras_params["tools"] = params["tools"]
+
+        cerebras_params["messages"] = cerebras_messages
+
+        # We use the chat completions endpoint, with max_retries set to 5 (in line with the typical retry loop)
+        client = Cerebras(api_key=self.api_key, max_retries=5)
+
+        # Token counts will be returned
+        prompt_tokens = 0
+        completion_tokens = 0
+        total_tokens = 0
+
+        # Streaming tool call recommendations
+        streaming_tool_calls = []
+
+        ans = None
+        try:
+            response = client.chat.completions.create(**cerebras_params)
+        except Exception as e:
+            raise RuntimeError(f"Cerebras exception occurred: {e}")
+        else:
+
+            if cerebras_params["stream"]:
+                # Read in the chunks as they stream, taking in tool_calls which may arrive across
+                # multiple chunks when more than one is suggested
+                ans = ""
+                for chunk in response:
+                    # Grab first choice, which _should_ always be generated.
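+                    # delta.content may be None (e.g. on tool-call chunks), so coalesce to "" before appending.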
+                    ans = ans + (chunk.choices[0].delta.content or "")
+
+                    if chunk.choices[0].delta.tool_calls:
+                        # We have a tool call recommendation
+                        for tool_call in chunk.choices[0].delta.tool_calls:
+                            streaming_tool_calls.append(
+                                ChatCompletionMessageToolCall(
+                                    id=tool_call.id,
+                                    function={
+                                        "name": tool_call.function.name,
+                                        "arguments": tool_call.function.arguments,
+                                    },
+                                    type="function",
+                                )
+                            )
+
+                    if chunk.choices[0].finish_reason:
+                        prompt_tokens = chunk.x_cerebras.usage.prompt_tokens
+                        completion_tokens = chunk.x_cerebras.usage.completion_tokens
+                        total_tokens = chunk.x_cerebras.usage.total_tokens
+            else:
+                # Non-streaming finished
+                ans: str = response.choices[0].message.content
+
+                prompt_tokens = response.usage.prompt_tokens
+                completion_tokens = response.usage.completion_tokens
+                total_tokens = response.usage.total_tokens
+
+        if response is not None:
+            if isinstance(response, Stream):
+                # Streaming response
+                if chunk.choices[0].finish_reason == "tool_calls":
+                    cerebras_finish = "tool_calls"
+                    tool_calls = streaming_tool_calls
+                else:
+                    cerebras_finish = "stop"
+                    tool_calls = None
+
+                response_content = ans
+                response_id = chunk.id
+            else:
+                # Non-streaming response
+                # If we have tool calls as the response, populate completed tool calls for our return OAI response
+                if response.choices[0].finish_reason == "tool_calls":
+                    cerebras_finish = "tool_calls"
+                    tool_calls = []
+                    for tool_call in response.choices[0].message.tool_calls:
+                        tool_calls.append(
+                            ChatCompletionMessageToolCall(
+                                id=tool_call.id,
+                                function={"name": tool_call.function.name, "arguments": tool_call.function.arguments},
+                                type="function",
+                            )
+                        )
+                else:
+                    cerebras_finish = "stop"
+                    tool_calls = None
+
+                response_content = response.choices[0].message.content
+                response_id = response.id
+        else:
+            raise RuntimeError("Failed to get response from Cerebras after retrying 5 times.")
+
+        # Convert the output into OpenAI's ChatCompletion format
+        message = ChatCompletionMessage(
+            role="assistant",
+            content=response_content,
+            function_call=None,
+            tool_calls=tool_calls,
+        )
+        choices = [Choice(finish_reason=cerebras_finish, index=0, message=message)]
+
+        response_oai = ChatCompletion(
+            id=response_id,
+            model=cerebras_params["model"],
+            created=int(time.time()),
+            object="chat.completion",
+            choices=choices,
+            usage=CompletionUsage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
+            ),
+            # Note: This seems to be a field that isn't in the schema of `ChatCompletion`, so Pydantic
+            # just adds it dynamically.
+            cost=calculate_cerebras_cost(prompt_tokens, completion_tokens, cerebras_params["model"]),
+        )
+
+        return response_oai
+
+
+def oai_messages_to_cerebras_messages(messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert messages from OAI format to Cerebras's format.
+    We correct for any specific role orders and types.
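+    Currently this only removes the 'name' field, which isn't passed through to Cerebras's otherwise OpenAI-compatible API.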
+    """
+
+    cerebras_messages = copy.deepcopy(messages)
+
+    # Remove the name field
+    for message in cerebras_messages:
+        if "name" in message:
+            message.pop("name", None)
+
+    return cerebras_messages
+
+
+def calculate_cerebras_cost(input_tokens: int, output_tokens: int, model: str) -> float:
+    """Calculate the cost of the completion using the Cerebras pricing."""
+    total = 0.0
+
+    if model in CEREBRAS_PRICING_1K:
+        input_cost_per_k, output_cost_per_k = CEREBRAS_PRICING_1K[model]
+        input_cost = (input_tokens / 1000) * input_cost_per_k
+        output_cost = (output_tokens / 1000) * output_cost_per_k
+        total = input_cost + output_cost
+    else:
+        warnings.warn(f"Cost calculation not available for model {model}", UserWarning)
+
+    return total
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 4b77815e7eb7..1748b28a7a1f 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -44,6 +44,13 @@
 TOOL_ENABLED = True
 ERROR = None
 
+try:
+    from autogen.oai.cerebras import CerebrasClient
+
+    cerebras_import_exception: Optional[ImportError] = None
+except ImportError as e:
+    cerebras_import_exception = e
+
 try:
     from autogen.oai.gemini import GeminiClient
 
@@ -505,6 +512,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
             self._configure_azure_openai(config, openai_config)
             client = AzureOpenAI(**openai_config)
             self._clients.append(OpenAIClient(client))
+        elif api_type is not None and api_type.startswith("cerebras"):
+            if cerebras_import_exception:
+                raise ImportError("Please install `cerebras_cloud_sdk` to use Cerebras OpenAI API.")
+            client = CerebrasClient(**openai_config)
+            self._clients.append(client)
         elif api_type is not None and api_type.startswith("google"):
             if gemini_import_exception:
                 raise ImportError("Please install `google-generativeai` to use Google OpenAI API.")
diff --git a/autogen/runtime_logging.py b/autogen/runtime_logging.py
index 0fd7cc2fc8b9..9036fe5c65cc 100644
--- a/autogen/runtime_logging.py
+++ b/autogen/runtime_logging.py
@@ -15,6 +15,7 @@
     from autogen import Agent, ConversableAgent, OpenAIWrapper
     from autogen.oai.anthropic import AnthropicClient
     from autogen.oai.bedrock import BedrockClient
+    from autogen.oai.cerebras import CerebrasClient
     from autogen.oai.cohere import CohereClient
     from autogen.oai.gemini import GeminiClient
     from autogen.oai.groq import GroqClient
@@ -116,6 +117,7 @@ def log_new_client(
     client: Union[
         AzureOpenAI,
         OpenAI,
+        CerebrasClient,
         GeminiClient,
         AnthropicClient,
         MistralAIClient,
diff --git a/samples/apps/autogen-studio/autogenstudio/datamodel.py b/samples/apps/autogen-studio/autogenstudio/datamodel.py
index 92d60cf5c525..ee48818d599f 100644
--- a/samples/apps/autogen-studio/autogenstudio/datamodel.py
+++ b/samples/apps/autogen-studio/autogenstudio/datamodel.py
@@ -126,6 +126,7 @@ class LLMConfig(SQLModel, table=False):
 
 class ModelTypes(str, Enum):
     openai = "open_ai"
+    cerebras = "cerebras"
     google = "google"
     azure = "azure"
     anthropic = "anthropic"
diff --git a/setup.py b/setup.py
index 362aa1217986..63a9c9745ffb 100644
--- a/setup.py
+++ b/setup.py
@@ -101,6 +101,7 @@
     "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
     "long-context": ["llmlingua<0.3"],
     "anthropic": ["anthropic>=0.23.1"],
+    "cerebras": ["cerebras_cloud_sdk>=1.0.0"],
     "mistral": ["mistralai>=1.0.1"],
     "groq": ["groq>=0.9.0"],
     "cohere": ["cohere>=5.5.8"],
diff --git a/test/oai/test_cerebras.py b/test/oai/test_cerebras.py
new file mode 100644
index 000000000000..7f84ae3f9d56
--- /dev/null
+++ b/test/oai/test_cerebras.py
@@ -0,0 +1,248 @@
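+"""Unit tests for CerebrasClient; responses are mocked, so no network access or real CEREBRAS_API_KEY is required."""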
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+try:
+    from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost
+
+    skip = False
+except ImportError:
+    CerebrasClient = object
+    InternalServerError = object
+    skip = True
+
+
+# Fixtures for mock data
+@pytest.fixture
+def mock_response():
+    class MockResponse:
+        def __init__(self, text, choices, usage, cost, model):
+            self.text = text
+            self.choices = choices
+            self.usage = usage
+            self.cost = cost
+            self.model = model
+
+    return MockResponse
+
+
+@pytest.fixture
+def cerebras_client():
+    return CerebrasClient(api_key="fake_api_key")
+
+
+skip_reason = "Cerebras dependency is not installed"
+
+
+# Test initialization and configuration
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_initialization():
+
+    # Missing any api_key
+    with pytest.raises(AssertionError) as assertinfo:
+        CerebrasClient()  # Should raise an AssertionError due to missing api_key
+
+    assert (
+        "Please include the api_key in your config list entry for Cerebras or set the CEREBRAS_API_KEY env variable."
+        in str(assertinfo.value)
+    )
+
+    # Creation works
+    CerebrasClient(api_key="fake_api_key")  # Should create okay now.
+
+
+# Test standard initialization
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_valid_initialization(cerebras_client):
+    assert cerebras_client.api_key == "fake_api_key", "Config api_key should be correctly set"
+
+
+# Test parameters
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_parsing_params(cerebras_client):
+    # All parameters
+    params = {
+        "model": "llama3.1-8b",
+        "max_tokens": 1000,
+        "seed": 42,
+        "stream": False,
+        "temperature": 1,
+        "top_p": 0.8,
+    }
+    expected_params = {
+        "model": "llama3.1-8b",
+        "max_tokens": 1000,
+        "seed": 42,
+        "stream": False,
+        "temperature": 1,
+        "top_p": 0.8,
+    }
+    result = cerebras_client.parse_params(params)
+    assert result == expected_params
+
+    # Only model, others set as defaults
+    params = {
+        "model": "llama3.1-8b",
+    }
+    expected_params = {
+        "model": "llama3.1-8b",
+        "max_tokens": None,
+        "seed": None,
+        "stream": False,
+        "temperature": 1,
+        "top_p": None,
+    }
+    result = cerebras_client.parse_params(params)
+    assert result == expected_params
+
+    # Incorrect types, defaults should be set, will show warnings but not trigger assertions
+    params = {
+        "model": "llama3.1-8b",
+        "max_tokens": "1000",
+        "seed": "42",
+        "stream": "False",
+        "temperature": "1",
+        "top_p": "0.8",
+    }
+    result = cerebras_client.parse_params(params)
+    assert result == expected_params
+
+    # Values outside bounds, should warn and set to defaults
+    params = {
+        "model": "llama3.1-8b",
+        "temperature": 33123,
+    }
+    result = cerebras_client.parse_params(params)
+    assert result == expected_params
+
+    # No model
+    params = {
+        "temperature": 1,
+    }
+
+    with pytest.raises(AssertionError) as assertinfo:
+        result = cerebras_client.parse_params(params)
+
+    assert "Please specify the 'model' in your config list entry to nominate the Cerebras model to use." in str(
+        assertinfo.value
+    )
+
+
+# Test cost calculation
+@pytest.mark.skipif(skip, reason=skip_reason)
+def test_cost_calculation(mock_response):
+    response = mock_response(
+        text="Example response",
+        choices=[{"message": "Test message 1"}],
+        usage={"prompt_tokens": 500, "completion_tokens": 300, "total_tokens": 800},
+        cost=None,
+        model="llama3.1-70b",
+    )
+    calculated_cost = calculate_cerebras_cost(
+        response.usage["prompt_tokens"], response.usage["completion_tokens"], response.model
+    )
+
+    # Convert cost per million tokens to cost per token.
+    expected_cost = (
+        response.usage["prompt_tokens"] * 0.6 / 1000000 + response.usage["completion_tokens"] * 0.6 / 1000000
+    )
+
+    assert calculated_cost == expected_cost, f"Cost for this should be ${expected_cost} but got ${calculated_cost}"
+
+
+# Test text generation
+@pytest.mark.skipif(skip, reason=skip_reason)
+@patch("autogen.oai.cerebras.CerebrasClient.create")
+def test_create_response(mock_chat, cerebras_client):
+    # Mock CerebrasClient.chat response
+    mock_cerebras_response = MagicMock()
+    mock_cerebras_response.choices = [
+        MagicMock(finish_reason="stop", message=MagicMock(content="Example Cerebras response", tool_calls=None))
+    ]
+    mock_cerebras_response.id = "mock_cerebras_response_id"
+    mock_cerebras_response.model = "llama3.1-70b"
+    mock_cerebras_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20)  # Example token usage
+
+    mock_chat.return_value = mock_cerebras_response
+
+    # Test parameters
+    params = {
+        "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "World"}],
+        "model": "llama3.1-70b",
+    }
+
+    # Call the create method
+    response = cerebras_client.create(params)
+
+    # Assertions to check if response is structured as expected
+    assert (
+        response.choices[0].message.content == "Example Cerebras response"
+    ), "Response content should match expected output"
+    assert response.id == "mock_cerebras_response_id", "Response ID should match the mocked response ID"
+    assert response.model == "llama3.1-70b", "Response model should match the mocked response model"
+    assert response.usage.prompt_tokens == 10, "Response prompt tokens should match the mocked response usage"
+    assert response.usage.completion_tokens == 20, "Response completion tokens should match the mocked response usage"
+
+
+# Test functions/tools
+@pytest.mark.skipif(skip, reason=skip_reason)
+@patch("autogen.oai.cerebras.CerebrasClient.create")
+def test_create_response_with_tool_call(mock_chat, cerebras_client):
+    # Mock `cerebras_response = client.chat(**cerebras_params)`
+    mock_function = MagicMock(name="currency_calculator")
+    mock_function.name = "currency_calculator"
+    mock_function.arguments = '{"base_currency": "EUR", "quote_currency": "USD", "base_amount": 123.45}'
+
+    mock_function_2 = MagicMock(name="get_weather")
+    mock_function_2.name = "get_weather"
+    mock_function_2.arguments = '{"location": "Chicago"}'
+
+    mock_chat.return_value = MagicMock(
+        choices=[
+            MagicMock(
+                finish_reason="tool_calls",
+                message=MagicMock(
+                    content="Sample text about the functions",
+                    tool_calls=[
+                        MagicMock(id="gdRdrvnHh", function=mock_function),
+                        MagicMock(id="abRdrvnHh", function=mock_function_2),
+                    ],
+                ),
+            )
+        ],
+        id="mock_cerebras_response_id",
+        model="llama3.1-70b",
+        usage=MagicMock(prompt_tokens=10, completion_tokens=20),
+    )
+
+    # Construct parameters
+    converted_functions = [
+        {
+            "type": "function",
+            "function": {
+                "description": "Currency exchange calculator.",
+                "name": "currency_calculator",
"parameters": { + "type": "object", + "properties": { + "base_amount": {"type": "number", "description": "Amount of currency in base_currency"}, + }, + "required": ["base_amount"], + }, + }, + } + ] + cerebras_messages = [ + {"role": "user", "content": "How much is 123.45 EUR in USD?"}, + {"role": "assistant", "content": "World"}, + ] + + # Call the create method + response = cerebras_client.create( + {"messages": cerebras_messages, "tools": converted_functions, "model": "llama3.1-70b"} + ) + + # Assertions to check if the functions and content are included in the response + assert response.choices[0].message.content == "Sample text about the functions" + assert response.choices[0].message.tool_calls[0].function.name == "currency_calculator" + assert response.choices[0].message.tool_calls[1].function.name == "get_weather" diff --git a/website/docs/topics/non-openai-models/cloud-cerebras.ipynb b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb new file mode 100644 index 000000000000..a8e1d3940f4b --- /dev/null +++ b/website/docs/topics/non-openai-models/cloud-cerebras.ipynb @@ -0,0 +1,505 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cerebras\n", + "\n", + "[Cerebras](https://cerebras.ai) has developed the world's largest and fastest AI processor, the Wafer-Scale Engine-3 (WSE-3). Notably, the CS-3 system can run large language models like Llama-3.1-8B and Llama-3.1-70B at extremely fast speeds, making it an ideal platform for demanding AI workloads.\n", + "\n", + "While it's technically possible to adapt AutoGen to work with Cerebras' API by updating the `base_url`, this approach may not fully account for minor differences in parameter support. Using this library will also allow for tracking of the API costs based on actual token usage.\n", + "\n", + "For more information about Cerebras Cloud, visit [cloud.cerebras.ai](https://cloud.cerebras.ai). Their API reference is available at [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Requirements\n", + "To use Cerebras with AutoGen, install the `pyautogen[cerebras]` package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install pyautogen[\"cerebras\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Getting Started\n", + "\n", + "Cerebras provides a number of models to use. 
+    "\n",
+    "See the sample `OAI_CONFIG_LIST` below showing how the Cerebras AI client class is used by specifying the `api_type` as `cerebras`.\n",
+    "```python\n",
+    "[\n",
+    "    {\n",
+    "        \"model\": \"llama3.1-8b\",\n",
+    "        \"api_key\": \"your Cerebras API Key goes here\",\n",
+    "        \"api_type\": \"cerebras\"\n",
+    "    },\n",
+    "    {\n",
+    "        \"model\": \"llama3.1-70b\",\n",
+    "        \"api_key\": \"your Cerebras API Key goes here\",\n",
+    "        \"api_type\": \"cerebras\"\n",
+    "    }\n",
+    "]\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Credentials\n",
+    "\n",
+    "Get an API Key from [cloud.cerebras.ai](https://cloud.cerebras.ai/) and add it to your environment variables:\n",
+    "\n",
+    "```\n",
+    "export CEREBRAS_API_KEY=\"your-api-key-here\"\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## API parameters\n",
+    "\n",
+    "The following parameters can be added to your config for the Cerebras API. See [this link](https://inference-docs.cerebras.ai/api-reference/chat-completions) for further information on them and their default values.\n",
+    "\n",
+    "- max_tokens (null, integer >= 0)\n",
+    "- seed (number)\n",
+    "- stream (True or False)\n",
+    "- temperature (number 0..1.5)\n",
+    "- top_p (number)\n",
+    "\n",
+    "Example:\n",
+    "```python\n",
+    "[\n",
+    "    {\n",
+    "        \"model\": \"llama3.1-70b\",\n",
+    "        \"api_key\": \"your Cerebras API Key goes here\",\n",
+    "        \"api_type\": \"cerebras\",\n",
+    "        \"max_tokens\": 10000,\n",
+    "        \"seed\": 1234,\n",
+    "        \"stream\": True,\n",
+    "        \"temperature\": 0.5,\n",
+    "        \"top_p\": 0.2, # Note: It is recommended to set temperature or top_p but not both.\n",
+    "    }\n",
+    "]\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Two-Agent Coding Example\n",
+    "\n",
+    "In this example, we run a two-agent chat in which an AssistantAgent (primarily a coding agent) generates code to count the number of prime numbers between 1 and 10,000, and a UserProxyAgent then executes it.\n",
+    "\n",
+    "We'll use Meta's Llama-3.1-70B model, which is suitable for coding."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "from autogen.oai.cerebras import CerebrasClient, calculate_cerebras_cost\n",
+    "\n",
+    "config_list = [{\"model\": \"llama3.1-70b\", \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"), \"api_type\": \"cerebras\"}]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Importantly, we have tweaked the system message so that the model doesn't return the termination keyword (which we've changed to FINISH) together with the code block, since the chat would otherwise terminate before the code is executed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pathlib import Path\n",
+    "\n",
+    "from autogen import AssistantAgent, UserProxyAgent\n",
+    "from autogen.coding import LocalCommandLineCodeExecutor\n",
+    "\n",
+    "# Setting up the code executor\n",
+    "workdir = Path(\"coding\")\n",
+    "workdir.mkdir(exist_ok=True)\n",
+    "code_executor = LocalCommandLineCodeExecutor(work_dir=workdir)\n",
+    "\n",
+    "# Setting up the agents\n",
+    "\n",
+    "# The UserProxyAgent will execute the code that the AssistantAgent provides\n",
+    "user_proxy_agent = UserProxyAgent(\n",
+    "    name=\"User\",\n",
+    "    code_execution_config={\"executor\": code_executor},\n",
+    "    is_termination_msg=lambda msg: \"FINISH\" in msg.get(\"content\"),\n",
+    ")\n",
+    "\n",
+    "system_message = \"\"\"You are a helpful AI assistant who writes code and the user executes it.\n",
+    "Solve tasks using your coding and language skills.\n",
+    "In the following cases, suggest python code (in a python coding block) for the user to execute.\n",
+    "Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.\n",
+    "When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.\n",
+    "Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.\n",
+    "If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.\n",
+    "When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.\n",
+    "IMPORTANT: Wait for the user to execute your code and then you can reply with the word \"FINISH\". DO NOT OUTPUT \"FINISH\" after your code block.\"\"\"\n",
+    "\n",
+    "# The AssistantAgent, using Cerebras AI's model, will take the coding request and return code\n",
+    "assistant_agent = AssistantAgent(\n",
+    "    name=\"Cerebras Assistant\",\n",
+    "    system_message=system_message,\n",
+    "    llm_config={\"config_list\": config_list},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[33mUser\u001b[0m (to Cerebras Assistant):\n",
+      "\n",
+      "Provide code to count the number of prime numbers from 1 to 10000.\n",
+      "\n",
+      "--------------------------------------------------------------------------------\n",
+      "\u001b[33mCerebras Assistant\u001b[0m (to User):\n",
+      "\n",
+      "To count the number of prime numbers from 1 to 10000, we will utilize a simple algorithm that checks each number in the range to see if it is prime. A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.\n",
+      "\n",
+      "Here's how we can do it using a Python script:\n",
+      "\n",
+      "```python\n",
+      "def count_primes(n):\n",
+      "    primes = 0\n",
+      "    for possiblePrime in range(2, n + 1):\n",
+      "        # Assume number is prime until shown it is not. \n",
+      "        isPrime = True\n",
+      "        for num in range(2, int(possiblePrime ** 0.5) + 1):\n",
+      "            if possiblePrime % num == 0:\n",
+      "                isPrime = False\n",
+      "                break\n",
+      "        if isPrime:\n",
+      "            primes += 1\n",
+      "    return primes\n",
+      "\n",
+      "# Counting prime numbers from 1 to 10000\n",
+      "count = count_primes(10000)\n",
+      "print(count)\n",
+      "```\n",
+      "\n",
+      "Please execute this code. I will respond with \"FINISH\" after you provide the result.\n",
+      "\n",
+      "--------------------------------------------------------------------------------\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Replying as User. Provide feedback to Cerebras Assistant. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: \n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[31m\n",
+      ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Start the chat, with the UserProxyAgent asking the AssistantAgent the message\n",
+    "chat_result = user_proxy_agent.initiate_chat(\n",
+    "    assistant_agent,\n",
+    "    message=\"Provide code to count the number of prime numbers from 1 to 10000.\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tool Call Example\n",
+    "\n",
+    "In this example, instead of writing code, we will show how Meta's Llama-3.1-70B model can perform parallel tool calling, where it recommends calling more than one tool at a time.\n",
+    "\n",
+    "We'll use a simple travel agent assistant program where we have a couple of tools for weather and currency conversion.\n",
+    "\n",
+    "We start by importing libraries and setting up our configuration to use Llama-3.1-70B and the `cerebras` client class."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import os\n",
+    "from typing import Literal\n",
+    "\n",
+    "from typing_extensions import Annotated\n",
+    "\n",
+    "import autogen\n",
+    "\n",
+    "config_list = [\n",
+    "    {\n",
+    "        \"model\": \"llama3.1-70b\",\n",
+    "        \"api_key\": os.environ.get(\"CEREBRAS_API_KEY\"),\n",
+    "        \"api_type\": \"cerebras\",\n",
+    "    }\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Create our two agents."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create the agent for tool calling\n",
+    "chatbot = autogen.AssistantAgent(\n",
+    "    name=\"chatbot\",\n",
+    "    system_message=\"\"\"\n",
+    "        For currency exchange and weather forecasting tasks,\n",
+    "        only use the functions you have been provided with.\n",
+    "        When you summarize, make sure you've considered ALL previous instructions.\n",
+    "        Output 'HAVE FUN!' when an answer has been provided.\n",
+    "    \"\"\",\n",
+    "    llm_config={\"config_list\": config_list},\n",
+    ")\n",
+    "\n",
+    "# Note that we have changed the termination string to be \"HAVE FUN!\"\n",
+    "user_proxy = autogen.UserProxyAgent(\n",
+    "    name=\"user_proxy\",\n",
+    "    is_termination_msg=lambda x: x.get(\"content\", \"\") and \"HAVE FUN!\" in x.get(\"content\", \"\"),\n",
+    "    human_input_mode=\"NEVER\",\n",
+    "    max_consecutive_auto_reply=1,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Create the two functions, annotating them so that those descriptions can be passed through to the LLM.\n",
+    "\n",
+    "We associate them with the agents using `register_for_execution` for the user_proxy so it can execute the function and `register_for_llm` for the chatbot (powered by the LLM) so it can pass the function definitions to the LLM."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Currency Exchange function\n",
+    "\n",
+    "CurrencySymbol = Literal[\"USD\", \"EUR\"]\n",
+    "\n",
+    "# Define our function that we expect to call\n",
+    "\n",
+    "\n",
+    "def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:\n",
+    "    if base_currency == quote_currency:\n",
+    "        return 1.0\n",
+    "    elif base_currency == \"USD\" and quote_currency == \"EUR\":\n",
+    "        return 1 / 1.1\n",
+    "    elif base_currency == \"EUR\" and quote_currency == \"USD\":\n",
+    "        return 1.1\n",
+    "    else:\n",
+    "        raise ValueError(f\"Unknown currencies {base_currency}, {quote_currency}\")\n",
+    "\n",
+    "\n",
+    "# Register the function with the agent\n",
+    "\n",
+    "\n",
+    "@user_proxy.register_for_execution()\n",
+    "@chatbot.register_for_llm(description=\"Currency exchange calculator.\")\n",
+    "def currency_calculator(\n",
+    "    base_amount: Annotated[float, \"Amount of currency in base_currency\"],\n",
+    "    base_currency: Annotated[CurrencySymbol, \"Base currency\"] = \"USD\",\n",
+    "    quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
+    ") -> str:\n",
+    "    quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
+    "    return f\"{format(quote_amount, '.2f')} {quote_currency}\"\n",
+    "\n",
+    "\n",
+    "# Weather function\n",
+    "\n",
+    "\n",
+    "# Example function to make available to model\n",
+    "def get_current_weather(location, unit=\"fahrenheit\"):\n",
+    "    \"\"\"Get the weather for some location\"\"\"\n",
+    "    if \"chicago\" in location.lower():\n",
+    "        return json.dumps({\"location\": \"Chicago\", \"temperature\": \"13\", \"unit\": unit})\n",
+    "    elif \"san francisco\" in location.lower():\n",
+    "        return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"55\", \"unit\": unit})\n",
+    "    elif \"new york\" in location.lower():\n",
+    "        return json.dumps({\"location\": \"New York\", \"temperature\": \"11\", \"unit\": unit})\n",
+    "    else:\n",
+    "        return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
+    "\n",
+    "\n",
+    "# Register the function with the agent\n",
+    "\n",
+    "\n",
+    "@user_proxy.register_for_execution()\n",
+    "@chatbot.register_for_llm(description=\"Weather forecast for US cities.\")\n",
+    "def weather_forecast(\n",
+    "    location: Annotated[str, \"City name\"],\n",
+    ") -> str:\n",
+    "    weather_details = get_current_weather(location=location)\n",
+    "    weather = json.loads(weather_details)\n",
+    "    return f\"{weather['location']} will be {weather['temperature']} degrees {weather['unit']}\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
"metadata": {}, + "source": [ + "We pass through our customer's message and run the chat.\n", + "\n", + "Finally, we ask the LLM to summarise the chat and print that out." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? Throw a few holiday tips in as well.\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "\u001b[32m***** Suggested tool call (210f6ac6d): weather_forecast *****\u001b[0m\n", + "Arguments: \n", + "{\"location\": \"New York\"}\n", + "\u001b[32m*************************************************************\u001b[0m\n", + "\u001b[32m***** Suggested tool call (3c00ac7d5): currency_calculator *****\u001b[0m\n", + "Arguments: \n", + "{\"base_amount\": 123.45, \"base_currency\": \"EUR\", \"quote_currency\": \"USD\"}\n", + "\u001b[32m****************************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION weather_forecast...\u001b[0m\n", + "\u001b[35m\n", + ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (210f6ac6d) *****\u001b[0m\n", + "New York will be 11 degrees fahrenheit\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", + "\n", + "\u001b[32m***** Response from calling tool (3c00ac7d5) *****\u001b[0m\n", + "135.80 USD\n", + "\u001b[32m**************************************************\u001b[0m\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", + "\n", + "New York will be 11 degrees fahrenheit.\n", + "123.45 EUR is equivalent to 135.80 USD.\n", + " \n", + "For a great holiday, explore the Statue of Liberty, take a walk through Central Park, or visit one of the many world-class museums. Also, you'll find great food ranging from bagels to fine dining experiences. HAVE FUN!\n", + "\n", + "--------------------------------------------------------------------------------\n", + "LLM SUMMARY: New York will be 11 degrees fahrenheit. 123.45 EUR is equivalent to 135.80 USD. Explore the Statue of Liberty, walk through Central Park, or visit one of the many world-class museums for a great holiday in New York.\n", + "\n", + "Duration: 73.97937774658203ms\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "start_time = time.time()\n", + "\n", + "# start the conversation\n", + "res = user_proxy.initiate_chat(\n", + " chatbot,\n", + " message=\"What's the weather in New York and can you tell me how much is 123.45 EUR in USD so I can spend it on my holiday? 
+    "    summary_method=\"reflection_with_llm\",\n",
+    ")\n",
+    "\n",
+    "end_time = time.time()\n",
+    "\n",
+    "print(f\"LLM SUMMARY: {res.summary['content']}\\n\\nDuration: {(end_time - start_time) * 1000}ms\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can see that the Cerebras Wafer-Scale Engine-3 (WSE-3) completed the query in 74ms -- faster than the blink of an eye!"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}