Skip to content

Commit

Permalink
Add Groq Support (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
pushpak1300 authored Oct 27, 2024
1 parent 12dcbd1 commit e4d32ff
Show file tree
Hide file tree
Showing 16 changed files with 753 additions and 1 deletion.
4 changes: 4 additions & 0 deletions config/prism.php
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,9 @@
'api_key' => env('MISTRAL_API_KEY', ''),
'url' => env('MISTRAL_URL', 'https://api.mistral.ai/v1'),
],
'groq' => [
'api_key' => env('GROQ_API_KEY', ''),
'url' => env('GROQ_URL', 'https://api.groq.com/openai/v1'),
],
],
];
1 change: 1 addition & 0 deletions src/Enums/Provider.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ enum Provider: string
case Ollama = 'ollama';
case OpenAI = 'openai';
case Mistral = 'mistral';
case Groq = 'groq';
}
12 changes: 12 additions & 0 deletions src/PrismManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use EchoLabs\Prism\Contracts\Provider;
use EchoLabs\Prism\Enums\Provider as ProviderEnum;
use EchoLabs\Prism\Providers\Anthropic\Anthropic;
use EchoLabs\Prism\Providers\Groq\Groq;
use EchoLabs\Prism\Providers\Mistral\Mistral;
use EchoLabs\Prism\Providers\Ollama\Ollama;
use EchoLabs\Prism\Providers\OpenAI\OpenAI;
Expand Down Expand Up @@ -135,4 +136,15 @@ protected function getConfig(string $name): ?array

return ['driver' => 'null'];
}

/**
* @param array<string, string> $config
*/
protected function createGroqProvider(array $config): Groq
{
return new Groq(
url: $config['url'],
apiKey: $config['api_key'],
);
}
}
55 changes: 55 additions & 0 deletions src/Providers/Groq/Client.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\Groq;

use Illuminate\Http\Client\PendingRequest;
use Illuminate\Http\Client\Response;
use Illuminate\Support\Facades\Http;

class Client
{
protected PendingRequest $client;

/**
* @param array<string, mixed> $options
*/
public function __construct(
public readonly string $url,
public readonly string $apiKey,
public readonly array $options = [],
) {
$this->client = Http::withHeaders(array_filter([
'Authorization' => sprintf('Bearer %s', $this->apiKey),
]))
->withOptions($this->options)
->baseUrl($this->url);
}

/**
* @param array<int, mixed> $messages
* @param array<int, mixed>|null $tools
*/
public function messages(
string $model,
array $messages,
?int $maxTokens,
int|float|null $temperature,
int|float|null $topP,
?array $tools,
): Response {
return $this->client->post(
'chat/completions',
array_merge([
'model' => $model,
'messages' => $messages,
'max_tokens' => $maxTokens ?? 2048,
], array_filter([
'temperature' => $temperature,
'top_p' => $topP,
'tools' => $tools,
]))
);
}
}
105 changes: 105 additions & 0 deletions src/Providers/Groq/Groq.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\Groq;

use EchoLabs\Prism\Contracts\Provider;
use EchoLabs\Prism\Enums\FinishReason;
use EchoLabs\Prism\Exceptions\PrismException;
use EchoLabs\Prism\Providers\ProviderResponse;
use EchoLabs\Prism\Requests\TextRequest;
use EchoLabs\Prism\ValueObjects\ToolCall;
use EchoLabs\Prism\ValueObjects\Usage;
use Throwable;

class Groq implements Provider
{
public function __construct(
public readonly string $url,
public readonly string $apiKey,
) {}

#[\Override]
public function text(TextRequest $request): ProviderResponse
{
try {
$response = $this
->client($request->clientOptions)
->messages(
model: $request->model,
messages: (new MessageMap(
$request->messages,
$request->systemPrompt ?? '',
))(),
maxTokens: $request->maxTokens,
temperature: $request->temperature,
topP: $request->topP,
tools: Tool::map($request->tools),
);
} catch (Throwable $e) {
throw PrismException::providerRequestError($request->model, $e);
}

$data = $response->json();

if (data_get($data, 'message') || ! $data) {
throw PrismException::providerResponseError(vsprintf(
'Mistral Error: %s',
[
data_get($data, 'message', 'unknown'),
]
));
}

return new ProviderResponse(
text: data_get($data, 'choices.0.message.content') ?? '',
toolCalls: $this->mapToolCalls(data_get($data, 'choices.0.message.tool_calls', []) ?? []),
usage: new Usage(
data_get($data, 'usage.prompt_tokens'),
data_get($data, 'usage.completion_tokens'),
),
finishReason: $this->mapFinishReason(data_get($data, 'choices.0.finish_reason', '')),
response: [
'id' => data_get($data, 'id'),
'model' => data_get($data, 'model'),
]
);
}

/**
* @param array<int, array<string, mixed>> $toolCalls
* @return array<int, ToolCall>
*/
protected function mapToolCalls(array $toolCalls): array
{
return array_map(fn (array $toolCall): ToolCall => new ToolCall(
id: data_get($toolCall, 'id'),
name: data_get($toolCall, 'function.name'),
arguments: data_get($toolCall, 'function.arguments'),
), $toolCalls);
}

/**
* @param array<string, mixed> $options
*/
protected function client(array $options = []): Client
{
return new Client(
apiKey: $this->apiKey,
url: $this->url,
options: $options,
);
}

protected function mapFinishReason(string $stopReason): FinishReason
{
return match ($stopReason) {
'stop', => FinishReason::Stop,
'tool_calls' => FinishReason::ToolCalls,
'length' => FinishReason::Length,
'content_filter' => FinishReason::ContentFilter,
default => FinishReason::Unknown,
};
}
}
117 changes: 117 additions & 0 deletions src/Providers/Groq/MessageMap.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\Groq;

use EchoLabs\Prism\Contracts\Message;
use EchoLabs\Prism\ValueObjects\Messages\AssistantMessage;
use EchoLabs\Prism\ValueObjects\Messages\Support\Image;
use EchoLabs\Prism\ValueObjects\Messages\SystemMessage;
use EchoLabs\Prism\ValueObjects\Messages\ToolResultMessage;
use EchoLabs\Prism\ValueObjects\Messages\UserMessage;
use EchoLabs\Prism\ValueObjects\ToolCall;
use Exception;
use Illuminate\Support\Str;

class MessageMap
{
/** @var array<int, mixed> */
protected $mappedMessages = [];

/**
* @param array<int, Message> $messages
*/
public function __construct(
protected array $messages,
protected string $systemPrompt
) {
if ($systemPrompt !== '' && $systemPrompt !== '0') {
$this->messages = array_merge(
[new SystemMessage($systemPrompt)],
$this->messages
);
}
}

/**
* @return array<int, mixed>
*/
public function __invoke(): array
{
array_map(
fn (Message $message) => $this->mapMessage($message),
$this->messages
);

return $this->mappedMessages;
}

public function mapMessage(Message $message): void
{
match ($message::class) {
UserMessage::class => $this->mapUserMessage($message),
AssistantMessage::class => $this->mapAssistantMessage($message),
ToolResultMessage::class => $this->mapToolResultMessage($message),
SystemMessage::class => $this->mapSystemMessage($message),
default => throw new Exception('Could not map message type '.$message::class),
};
}

protected function mapSystemMessage(SystemMessage $message): void
{
$this->mappedMessages[] = [
'role' => 'system',
'content' => $message->content,
];
}

protected function mapToolResultMessage(ToolResultMessage $message): void
{
foreach ($message->toolResults as $toolResult) {
$this->mappedMessages[] = [
'role' => 'tool',
'tool_call_id' => $toolResult->toolCallId,
'content' => $toolResult->result,
];
}
}

protected function mapUserMessage(UserMessage $message): void
{
$imageParts = array_map(fn (Image $part): array => [
'type' => 'image_url',
'image_url' => [
'url' => Str::isUrl($part->image)
? $part->image
: sprintf('data:%s;base64,%s', $part->mimeType ?? 'image/jpeg', $part->image),
],
], $message->images());

$this->mappedMessages[] = [
'role' => 'user',
'content' => [
['type' => 'text', 'text' => $message->text()],
...$imageParts,
],
];
}

protected function mapAssistantMessage(AssistantMessage $message): void
{
$toolCalls = array_map(fn (ToolCall $toolCall): array => [
'id' => $toolCall->id,
'type' => 'function',
'function' => [
'name' => $toolCall->name,
'arguments' => json_encode($toolCall->arguments()),
],
], $message->toolCalls);

$this->mappedMessages[] = array_filter([
'role' => 'assistant',
'content' => $message->content,
'tool_calls' => $toolCalls,
]);
}
}
34 changes: 34 additions & 0 deletions src/Providers/Groq/Tool.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<?php

declare(strict_types=1);

namespace EchoLabs\Prism\Providers\Groq;

use EchoLabs\Prism\Providers\ProviderTool;
use EchoLabs\Prism\Tool as PrismTool;

class Tool extends ProviderTool
{
#[\Override]
public static function toArray(PrismTool $tool): array
{
return [
'type' => 'function',
'function' => [
'name' => $tool->name(),
'description' => $tool->description(),
'parameters' => [
'type' => 'object',
'properties' => collect($tool->parameters())
->keyBy('name')
->map(fn (array $field): array => [
'description' => $field['description'],
'type' => $field['type'],
])
->toArray(),
'required' => $tool->requiredParameters(),
],
],
];
}
}
30 changes: 30 additions & 0 deletions tests/Fixtures/groq/generate-text-with-a-prompt-1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"id": "chatcmpl-ea37c181-ed35-4bd4-af20-c1fcf203e0d8",
"object": "chat.completion",
"created": 1730033815,
"model": "llama3-8b-8192",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I am LLaMA, an AI assistant developed by Meta AI."
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"queue_time": 0.016606609,
"prompt_tokens": 13,
"prompt_time": 0.002070851,
"completion_tokens": 208,
"completion_time": 0.173333333,
"total_tokens": 221,
"total_time": 0.175404184
},
"system_fingerprint": "fp_af05557ca2",
"x_groq": {
"id": "req_01jb70t402ff28vnfp2dbp5eyx"
}
}
Loading

0 comments on commit e4d32ff

Please sign in to comment.