Skip to content

Commit

Permalink
Merge pull request #9 from GeraudBourdin/main
Browse files Browse the repository at this point in the history
Update tool_choice option
  • Loading branch information
GeraudBourdin authored Mar 5, 2024
2 parents efc0afb + 0e0b3d8 commit 3396ecb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 13 deletions.
43 changes: 30 additions & 13 deletions examples/function_calling.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

// export MISTRAL_API_KEY=your_api_key
$apiKey = getenv('MISTRAL_API_KEY');
$client = new MistralClient($apiKey);
$client = new MistralClient(apiKey: $apiKey);

// Assuming we have the following data
$data = [
Expand Down Expand Up @@ -88,21 +88,26 @@
];

// Create the tools definition
$transactionIdParam = new Parameter(type: Parameter::STRING_TYPE, name: 'transactionId', description: 'The transaction id.', required: true);
$transactionIdParam = new Parameter(
type: Parameter::STRING_TYPE,
name: 'transactionId',
description: 'The transaction id.',
required: true
);

$retrievePaymentStatusFunction = new FunctionTool(
'retrievePaymentStatus',
'Get payment status of a transaction id',
[
name: 'retrievePaymentStatus',
description: 'Get payment status of a transaction id',
parameters: [
$transactionIdParam
]
);


$retrievePaymentDateFunction = new FunctionTool(
'retrievePaymentDate',
'Get payment date of a transaction id',
[
name: 'retrievePaymentDate',
description: 'Get payment date of a transaction id',
parameters: [
$transactionIdParam
]
);
Expand Down Expand Up @@ -156,14 +161,15 @@
//]

$messages = new Messages();
$messages->addUserMessage('What\'s the status of my transaction?');
$messages->addUserMessage(content: "What's the status of my transaction?");

try {
$chatResponse = $client->chat(
messages: $messages,
params: [
'model' => $model,
'tools' => $tools
'tools' => $tools,
'tool_choice' => MistralClient::TOOL_CHOICE_AUTO
]
);
} catch (MistralClientException $e) {
Expand All @@ -183,12 +189,19 @@


// Push response to history
$messages->addAssistantMessage($chatResponse->getMessage());
$messages->addAssistantMessage(content: $chatResponse->getMessage());
// Add customer response
$messages->addUserMessage('My transaction ID is T1001.');
$messages->addUserMessage(content: 'My transaction ID is T1001.');

try {
$chatResponse = $client->chat(messages: $messages, params: ['model' => $model, 'tools' => $tools]);
$chatResponse = $client->chat(
messages: $messages,
params: [
'model' => $model,
'tools' => $tools,
'tool_choice' => MistralClient::TOOL_CHOICE_AUTO
]
);
} catch (MistralClientException $e) {
echo $e->getMessage();
exit(1);
Expand All @@ -202,3 +215,7 @@
$functionResult = $namesToFunctions[$functionName]($functionParams);

print_r($functionResult);
// Array
// (
// [status] => Paid
// )
7 changes: 7 additions & 0 deletions src/MistralClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
class MistralClient
{
const string DEFAULT_MODEL = 'open-mistral-7b';
const string TOOL_CHOICE_ANY = 'any';
const string TOOL_CHOICE_AUTO = 'auto';
const string TOOL_CHOICE_NONE = 'none';

const array RETRY_STATUS_CODES = [429, 500 => GenericRetryStrategy::IDEMPOTENT_METHODS, 502, 503, 504 => GenericRetryStrategy::IDEMPOTENT_METHODS];
protected const string END_OF_STREAM = "[DONE]";
Expand Down Expand Up @@ -142,6 +145,10 @@ protected function makeChatCompletionRequest(Messages $messages, array $params,
$return['tools'] = $params['tools'];
}

if (isset($return['tools']) && isset($params['tool_choice'])) {
$return['tool_choice'] = $params['tool_choice'];
}

return $return;
}

Expand Down

0 comments on commit 3396ecb

Please sign in to comment.