Skip to content

Commit

Permalink
Add support for Cohere
Browse files Browse the repository at this point in the history
  • Loading branch information
kharvd committed May 29, 2024
1 parent 7a2a728 commit 37d7d16
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 120
ignore = E402, W503
ignore = E203, E402, W503
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from attr import dataclass
import platform
from typing import Any, Dict, Iterator, Optional, TypedDict, List
from gptcli.cohere import CohereCompletionProvider

from gptcli.completion import (
CompletionEvent,
Expand Down Expand Up @@ -78,6 +79,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
return LLaMACompletionProvider()
elif model.startswith("chat-bison"):
return GoogleCompletionProvider()
elif model.startswith("command") or model.startswith("c4ai"):
return CohereCompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
140 changes: 140 additions & 0 deletions gptcli/cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import cohere
from typing import Iterator, List

from gptcli.completion import (
CompletionEvent,
CompletionProvider,
Message,
CompletionError,
BadRequestError,
MessageDeltaEvent,
Pricing,
UsageEvent,
)

# Cohere API key, read from the environment at import time. May also be
# overwritten at runtime by the CLI entry point when the user config
# provides `cohere_api_key` (it assigns `gptcli.cohere.api_key = ...`).
api_key = os.environ.get("COHERE_API_KEY")

# Mapping from gptcli message roles to Cohere's uppercase role names.
# NOTE(review): appears unused within this module — map_message below uses
# the typed cohere.Message_* constructors instead. Possibly dead code, or
# consumed elsewhere; confirm before removing.
ROLE_MAP = {
    "system": "SYSTEM",
    "user": "USER",
    "assistant": "CHATBOT",
}


def map_message(message: Message) -> cohere.Message:
    """Convert a gptcli ``Message`` into the corresponding Cohere message type.

    Raises:
        ValueError: if the message's role is not system/user/assistant.
    """
    # Dispatch table: gptcli role -> Cohere message constructor.
    constructors = {
        "system": cohere.Message_System,
        "user": cohere.Message_User,
        "assistant": cohere.Message_Chatbot,
    }
    role = message["role"]
    constructor = constructors.get(role)
    if constructor is None:
        raise ValueError(f"Unknown message role: {message['role']}")
    return constructor(message=message["content"])


class CohereCompletionProvider(CompletionProvider):
    """CompletionProvider backed by the Cohere chat API.

    Uses the module-level ``api_key`` (environment variable or user config)
    to authenticate a ``cohere.Client``.
    """

    def __init__(self):
        self.client = cohere.Client(api_key=api_key)

    def complete(
        self, messages: List[Message], args: dict, stream: bool = False
    ) -> Iterator[CompletionEvent]:
        """Yield completion events for ``messages`` from the Cohere chat API.

        Yields ``MessageDeltaEvent``s with response text (one per chunk when
        streaming, a single one otherwise), followed by a ``UsageEvent`` when
        token counts and pricing for the model are available.

        Raises:
            BadRequestError: on a Cohere 400-class request error.
            CompletionError: on rate-limit, server, or other Cohere API errors.
        """
        kwargs = {}
        if "temperature" in args:
            kwargs["temperature"] = args["temperature"]
        if "top_p" in args:
            # Cohere calls the nucleus-sampling parameter "p", not "top_p".
            kwargs["p"] = args["top_p"]

        # Strip the "oai-compat:" routing prefix, if present, before sending
        # the model name to the API.
        model = args["model"]
        if model.startswith("oai-compat:"):
            model = model[len("oai-compat:") :]

        # Cohere takes the system prompt as a separate "preamble" argument
        # rather than as a chat message.
        if messages[0]["role"] == "system":
            kwargs["preamble"] = messages[0]["content"]
            messages = messages[1:]

        # The API takes the latest user message separately from the history.
        message = messages[-1]
        assert message["role"] == "user", "Last message must be user message"

        chat_history = [map_message(m) for m in messages[:-1]]

        try:
            if stream:
                response_iter = self.client.chat_stream(
                    chat_history=chat_history,
                    message=message["content"],
                    model=model,
                    **kwargs,
                )

                for response in response_iter:
                    if response.event_type == "text-generation":
                        yield MessageDeltaEvent(response.text)

                    # The final "stream-end" event carries token usage.
                    # NOTE(review): the pricing lookup uses the original
                    # args["model"] (prefix not stripped), so an
                    # "oai-compat:"-prefixed model never yields a UsageEvent
                    # here — confirm this is intended.
                    if (
                        response.event_type == "stream-end"
                        and response.response.meta
                        and response.response.meta.tokens
                        and (pricing := COHERE_PRICING.get(args["model"]))
                    ):
                        input_tokens = int(
                            response.response.meta.tokens.input_tokens or 0
                        )
                        output_tokens = int(
                            response.response.meta.tokens.output_tokens or 0
                        )
                        total_tokens = input_tokens + output_tokens

                        yield UsageEvent.with_pricing(
                            prompt_tokens=input_tokens,
                            completion_tokens=output_tokens,
                            total_tokens=total_tokens,
                            pricing=pricing,
                        )

            else:
                response = self.client.chat(
                    chat_history=chat_history,
                    message=message["content"],
                    model=model,
                    **kwargs,
                )
                yield MessageDeltaEvent(response.text)

                # Same usage accounting as the streaming branch, but the
                # metadata hangs directly off the single response object.
                if (
                    response.meta
                    and response.meta.tokens
                    and (pricing := COHERE_PRICING.get(args["model"]))
                ):
                    input_tokens = int(response.meta.tokens.input_tokens or 0)
                    output_tokens = int(response.meta.tokens.output_tokens or 0)
                    total_tokens = input_tokens + output_tokens

                    yield UsageEvent.with_pricing(
                        prompt_tokens=input_tokens,
                        completion_tokens=output_tokens,
                        total_tokens=total_tokens,
                        pricing=pricing,
                    )

        # Translate Cohere SDK exceptions into gptcli's own error types so
        # callers only need to handle the gptcli.completion hierarchy.
        except cohere.BadRequestError as e:
            raise BadRequestError(e.body) from e
        except (
            cohere.TooManyRequestsError,
            cohere.InternalServerError,
            cohere.core.api_error.ApiError,  # type: ignore
        ) as e:
            raise CompletionError(e.body) from e


# Per-token USD pricing for Cohere models (prompt = input tokens,
# response = output tokens), expressed as dollars per single token.
# Models missing from this table produce no UsageEvent in complete().
COHERE_PRICING: dict[str, Pricing] = {
    "command-r": {
        "prompt": 0.5 / 1_000_000,
        "response": 1.5 / 1_000_000,
    },
    "command-r-plus": {
        "prompt": 3.0 / 1_000_000,
        "response": 15.0 / 1_000_000,
    },
}
1 change: 1 addition & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class GptCliConfig:
openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL")
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
log_file: Optional[str] = None
log_level: str = "INFO"
assistants: Dict[str, AssistantConfig] = {}
Expand Down
4 changes: 4 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import datetime
import google.generativeai as genai
import gptcli.anthropic
import gptcli.cohere
from gptcli.assistant import (
Assistant,
DEFAULT_ASSISTANTS,
Expand Down Expand Up @@ -189,6 +190,9 @@ def main():
if config.google_api_key:
genai.configure(api_key=config.google_api_key)

if config.cohere_api_key:
gptcli.cohere.api_key = config.cohere_api_key

if config.llama_models is not None:
init_llama_models(config.llama_models)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"anthropic==0.25.9",
"attrs==23.2.0",
"black==24.4.2",
"cohere==5.5.3",
"mistralai==0.1.8",
"google-generativeai==0.1.0",
"openai==1.30.1",
Expand Down

0 comments on commit 37d7d16

Please sign in to comment.