-
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
12dcbd1
commit e4d32ff
Showing
16 changed files
with
753 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace EchoLabs\Prism\Providers\Groq; | ||
|
||
use Illuminate\Http\Client\PendingRequest; | ||
use Illuminate\Http\Client\Response; | ||
use Illuminate\Support\Facades\Http; | ||
|
||
class Client | ||
{ | ||
protected PendingRequest $client; | ||
|
||
/** | ||
* @param array<string, mixed> $options | ||
*/ | ||
public function __construct( | ||
public readonly string $url, | ||
public readonly string $apiKey, | ||
public readonly array $options = [], | ||
) { | ||
$this->client = Http::withHeaders(array_filter([ | ||
'Authorization' => sprintf('Bearer %s', $this->apiKey), | ||
])) | ||
->withOptions($this->options) | ||
->baseUrl($this->url); | ||
} | ||
|
||
/** | ||
* @param array<int, mixed> $messages | ||
* @param array<int, mixed>|null $tools | ||
*/ | ||
public function messages( | ||
string $model, | ||
array $messages, | ||
?int $maxTokens, | ||
int|float|null $temperature, | ||
int|float|null $topP, | ||
?array $tools, | ||
): Response { | ||
return $this->client->post( | ||
'chat/completions', | ||
array_merge([ | ||
'model' => $model, | ||
'messages' => $messages, | ||
'max_tokens' => $maxTokens ?? 2048, | ||
], array_filter([ | ||
'temperature' => $temperature, | ||
'top_p' => $topP, | ||
'tools' => $tools, | ||
])) | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace EchoLabs\Prism\Providers\Groq; | ||
|
||
use EchoLabs\Prism\Contracts\Provider; | ||
use EchoLabs\Prism\Enums\FinishReason; | ||
use EchoLabs\Prism\Exceptions\PrismException; | ||
use EchoLabs\Prism\Providers\ProviderResponse; | ||
use EchoLabs\Prism\Requests\TextRequest; | ||
use EchoLabs\Prism\ValueObjects\ToolCall; | ||
use EchoLabs\Prism\ValueObjects\Usage; | ||
use Throwable; | ||
|
||
class Groq implements Provider | ||
{ | ||
public function __construct( | ||
public readonly string $url, | ||
public readonly string $apiKey, | ||
) {} | ||
|
||
#[\Override] | ||
public function text(TextRequest $request): ProviderResponse | ||
{ | ||
try { | ||
$response = $this | ||
->client($request->clientOptions) | ||
->messages( | ||
model: $request->model, | ||
messages: (new MessageMap( | ||
$request->messages, | ||
$request->systemPrompt ?? '', | ||
))(), | ||
maxTokens: $request->maxTokens, | ||
temperature: $request->temperature, | ||
topP: $request->topP, | ||
tools: Tool::map($request->tools), | ||
); | ||
} catch (Throwable $e) { | ||
throw PrismException::providerRequestError($request->model, $e); | ||
} | ||
|
||
$data = $response->json(); | ||
|
||
if (data_get($data, 'message') || ! $data) { | ||
throw PrismException::providerResponseError(vsprintf( | ||
'Mistral Error: %s', | ||
[ | ||
data_get($data, 'message', 'unknown'), | ||
] | ||
)); | ||
} | ||
|
||
return new ProviderResponse( | ||
text: data_get($data, 'choices.0.message.content') ?? '', | ||
toolCalls: $this->mapToolCalls(data_get($data, 'choices.0.message.tool_calls', []) ?? []), | ||
usage: new Usage( | ||
data_get($data, 'usage.prompt_tokens'), | ||
data_get($data, 'usage.completion_tokens'), | ||
), | ||
finishReason: $this->mapFinishReason(data_get($data, 'choices.0.finish_reason', '')), | ||
response: [ | ||
'id' => data_get($data, 'id'), | ||
'model' => data_get($data, 'model'), | ||
] | ||
); | ||
} | ||
|
||
/** | ||
* @param array<int, array<string, mixed>> $toolCalls | ||
* @return array<int, ToolCall> | ||
*/ | ||
protected function mapToolCalls(array $toolCalls): array | ||
{ | ||
return array_map(fn (array $toolCall): ToolCall => new ToolCall( | ||
id: data_get($toolCall, 'id'), | ||
name: data_get($toolCall, 'function.name'), | ||
arguments: data_get($toolCall, 'function.arguments'), | ||
), $toolCalls); | ||
} | ||
|
||
/** | ||
* @param array<string, mixed> $options | ||
*/ | ||
protected function client(array $options = []): Client | ||
{ | ||
return new Client( | ||
apiKey: $this->apiKey, | ||
url: $this->url, | ||
options: $options, | ||
); | ||
} | ||
|
||
protected function mapFinishReason(string $stopReason): FinishReason | ||
{ | ||
return match ($stopReason) { | ||
'stop', => FinishReason::Stop, | ||
'tool_calls' => FinishReason::ToolCalls, | ||
'length' => FinishReason::Length, | ||
'content_filter' => FinishReason::ContentFilter, | ||
default => FinishReason::Unknown, | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace EchoLabs\Prism\Providers\Groq; | ||
|
||
use EchoLabs\Prism\Contracts\Message; | ||
use EchoLabs\Prism\ValueObjects\Messages\AssistantMessage; | ||
use EchoLabs\Prism\ValueObjects\Messages\Support\Image; | ||
use EchoLabs\Prism\ValueObjects\Messages\SystemMessage; | ||
use EchoLabs\Prism\ValueObjects\Messages\ToolResultMessage; | ||
use EchoLabs\Prism\ValueObjects\Messages\UserMessage; | ||
use EchoLabs\Prism\ValueObjects\ToolCall; | ||
use Exception; | ||
use Illuminate\Support\Str; | ||
|
||
class MessageMap | ||
{ | ||
/** @var array<int, mixed> */ | ||
protected $mappedMessages = []; | ||
|
||
/** | ||
* @param array<int, Message> $messages | ||
*/ | ||
public function __construct( | ||
protected array $messages, | ||
protected string $systemPrompt | ||
) { | ||
if ($systemPrompt !== '' && $systemPrompt !== '0') { | ||
$this->messages = array_merge( | ||
[new SystemMessage($systemPrompt)], | ||
$this->messages | ||
); | ||
} | ||
} | ||
|
||
/** | ||
* @return array<int, mixed> | ||
*/ | ||
public function __invoke(): array | ||
{ | ||
array_map( | ||
fn (Message $message) => $this->mapMessage($message), | ||
$this->messages | ||
); | ||
|
||
return $this->mappedMessages; | ||
} | ||
|
||
public function mapMessage(Message $message): void | ||
{ | ||
match ($message::class) { | ||
UserMessage::class => $this->mapUserMessage($message), | ||
AssistantMessage::class => $this->mapAssistantMessage($message), | ||
ToolResultMessage::class => $this->mapToolResultMessage($message), | ||
SystemMessage::class => $this->mapSystemMessage($message), | ||
default => throw new Exception('Could not map message type '.$message::class), | ||
}; | ||
} | ||
|
||
protected function mapSystemMessage(SystemMessage $message): void | ||
{ | ||
$this->mappedMessages[] = [ | ||
'role' => 'system', | ||
'content' => $message->content, | ||
]; | ||
} | ||
|
||
protected function mapToolResultMessage(ToolResultMessage $message): void | ||
{ | ||
foreach ($message->toolResults as $toolResult) { | ||
$this->mappedMessages[] = [ | ||
'role' => 'tool', | ||
'tool_call_id' => $toolResult->toolCallId, | ||
'content' => $toolResult->result, | ||
]; | ||
} | ||
} | ||
|
||
protected function mapUserMessage(UserMessage $message): void | ||
{ | ||
$imageParts = array_map(fn (Image $part): array => [ | ||
'type' => 'image_url', | ||
'image_url' => [ | ||
'url' => Str::isUrl($part->image) | ||
? $part->image | ||
: sprintf('data:%s;base64,%s', $part->mimeType ?? 'image/jpeg', $part->image), | ||
], | ||
], $message->images()); | ||
|
||
$this->mappedMessages[] = [ | ||
'role' => 'user', | ||
'content' => [ | ||
['type' => 'text', 'text' => $message->text()], | ||
...$imageParts, | ||
], | ||
]; | ||
} | ||
|
||
protected function mapAssistantMessage(AssistantMessage $message): void | ||
{ | ||
$toolCalls = array_map(fn (ToolCall $toolCall): array => [ | ||
'id' => $toolCall->id, | ||
'type' => 'function', | ||
'function' => [ | ||
'name' => $toolCall->name, | ||
'arguments' => json_encode($toolCall->arguments()), | ||
], | ||
], $message->toolCalls); | ||
|
||
$this->mappedMessages[] = array_filter([ | ||
'role' => 'assistant', | ||
'content' => $message->content, | ||
'tool_calls' => $toolCalls, | ||
]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
<?php | ||
|
||
declare(strict_types=1); | ||
|
||
namespace EchoLabs\Prism\Providers\Groq; | ||
|
||
use EchoLabs\Prism\Providers\ProviderTool; | ||
use EchoLabs\Prism\Tool as PrismTool; | ||
|
||
class Tool extends ProviderTool | ||
{ | ||
#[\Override] | ||
public static function toArray(PrismTool $tool): array | ||
{ | ||
return [ | ||
'type' => 'function', | ||
'function' => [ | ||
'name' => $tool->name(), | ||
'description' => $tool->description(), | ||
'parameters' => [ | ||
'type' => 'object', | ||
'properties' => collect($tool->parameters()) | ||
->keyBy('name') | ||
->map(fn (array $field): array => [ | ||
'description' => $field['description'], | ||
'type' => $field['type'], | ||
]) | ||
->toArray(), | ||
'required' => $tool->requiredParameters(), | ||
], | ||
], | ||
]; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
{ | ||
"id": "chatcmpl-ea37c181-ed35-4bd4-af20-c1fcf203e0d8", | ||
"object": "chat.completion", | ||
"created": 1730033815, | ||
"model": "llama3-8b-8192", | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "I am LLaMA, an AI assistant developed by Meta AI." | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
} | ||
], | ||
"usage": { | ||
"queue_time": 0.016606609, | ||
"prompt_tokens": 13, | ||
"prompt_time": 0.002070851, | ||
"completion_tokens": 208, | ||
"completion_time": 0.173333333, | ||
"total_tokens": 221, | ||
"total_time": 0.175404184 | ||
}, | ||
"system_fingerprint": "fp_af05557ca2", | ||
"x_groq": { | ||
"id": "req_01jb70t402ff28vnfp2dbp5eyx" | ||
} | ||
} |
Oops, something went wrong.