Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic provider config #93

Merged
merged 6 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/core-concepts/structured-output.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ Structured output supports all the same options as text generation, including:
- Message history
- Tools and function calling
- System prompts
- withClientOptions
- withClientRetry
- usingProviderConfig

See the [Text Generation](./text-generation.md) documentation for details on these common settings.

Expand Down
3 changes: 3 additions & 0 deletions docs/core-concepts/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,7 @@ $fake->assertRequest(function ($requests) {
expect($requests[0]->provider)->toBe('anthropic');
expect($requests[0]->model)->toBe('claude-3-sonnet');
});

// Assert provider configuration
$fake->assertProviderConfig(['api_key' => 'sk-1234']);
```
4 changes: 4 additions & 0 deletions docs/core-concepts/text-generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http

Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http-client#main-content). You can use this method to set [retries](https://laravel.com/docs/11.x/http-client#retries) e.g. `->withClientRetry(3, 100)`.

`usingProviderConfig`

This allows for complete or partial override of the providers configuration. This is great for multi-tenant applications where users supply their own API keys. These values are merged with the original configuration allowing for partial or complete config override.

## Response Handling

The response object provides rich access to the generation results:
Expand Down
13 changes: 13 additions & 0 deletions src/Concerns/BuildsTextRequests.php
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ trait BuildsTextRequests

protected string $model;

/** @var array<string, mixed> */
protected array $providerConfig = [];

/** @var array<string, array<string, mixed>> */
protected $providerMeta = [];

Expand Down Expand Up @@ -194,4 +197,14 @@ protected function textRequest(): Request
providerMeta: $this->providerMeta,
);
}

/**
* @param array<string, mixed> $config
*/
public function usingProviderConfig(array $config): self
{
$this->providerConfig = $config;

return $this;
}
}
4 changes: 3 additions & 1 deletion src/Prism.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ public function __construct(
private readonly PrismFake $fake
) {}

public function resolve(ProviderEnum|string $name): Provider
public function resolve(ProviderEnum|string $name, array $providerConfig = []): Provider
{
$this->fake->setProviderConfig($providerConfig);

return $this->fake;
}
});
Expand Down
16 changes: 7 additions & 9 deletions src/PrismManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ public function __construct(
) {}

/**
* @param array<string, mixed> $providerConfig
*
* @throws InvalidArgumentException
*/
public function resolve(ProviderEnum|string $name): Provider
public function resolve(ProviderEnum|string $name, array $providerConfig = []): Provider
{
$name = $this->resolveName($name);

$config = $this->getConfig($name) ?? [];
$config = array_merge($this->getConfig($name), $providerConfig);

if (isset($this->customCreators[$name])) {
return $this->callCustomCreator($name, $config);
Expand Down Expand Up @@ -127,15 +129,11 @@ protected function callCustomCreator(string $provider, array $config): Provider
}

/**
* @return null|array<string, mixed>
* @return array<string, mixed>
*/
protected function getConfig(string $name): ?array
protected function getConfig(string $name): array
{
if ($name !== '' && $name !== '0') {
return config("prism.providers.{$name}");
}

return ['driver' => 'null'];
return config("prism.providers.{$name}", []);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/Structured/Generator.php
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ protected function decodeObject(string $responseText): ?array
protected function sendProviderRequest(): ProviderResponse
{
$response = resolve(PrismManager::class)
->resolve($this->provider)
->resolve($this->provider, $this->providerConfig)
->structured($this->structuredRequest());

$responseMessage = new AssistantMessage(
Expand Down
22 changes: 22 additions & 0 deletions src/Testing/PrismFake.php
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class PrismFake implements Provider
/** @var array<int, StructuredRequest|TextRequest|EmbeddingRequest> */
protected array $recorded = [];

/** @var array<string, mixed> */
protected $providerConfig = [];

/**
* @param array<int, ProviderResponse|EmbeddingResponse> $responses
*/
Expand Down Expand Up @@ -68,6 +71,14 @@ public function structured(StructuredRequest $request): ProviderResponse
);
}

/**
* @param array<string, mixed> $config
*/
public function setProviderConfig(array $config): void
{
$this->providerConfig = $config;
}

/**
* @param Closure(array<int, StructuredRequest|TextRequest|EmbeddingRequest>):void $fn
*/
Expand All @@ -89,6 +100,17 @@ public function assertPrompt(string $prompt): void
);
}

/**
* @param array<string, mixed> $providerConfig
*/
public function assertProviderConfig(array $providerConfig): void
{
PHPUnit::assertEqualsCanonicalizing(
$providerConfig,
$this->providerConfig
);
}

/**
* Assert number of calls made
*/
Expand Down
2 changes: 1 addition & 1 deletion src/Text/Generator.php
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public function generate(): Response
protected function sendProviderRequest(): ProviderResponse
{
$response = resolve(PrismManager::class)
->resolve($this->provider)
->resolve($this->provider, $this->providerConfig)
->text($this->textRequest());

$responseMessage = new AssistantMessage(
Expand Down
38 changes: 38 additions & 0 deletions tests/Generators/StructuredGeneratorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use EchoLabs\Prism\ValueObjects\Messages\UserMessage;
use EchoLabs\Prism\ValueObjects\ToolCall;
use EchoLabs\Prism\ValueObjects\Usage;
use Illuminate\Contracts\Foundation\Application;
use InvalidArgumentException;
use Tests\TestDoubles\TestProvider;

Expand Down Expand Up @@ -358,3 +359,40 @@
new UserMessage('Who are you?'),
]);
});

it('allows for custom provider configuration', function (): void {
$provider = new TestProvider;

$schema = new ObjectSchema(
'model',
'An object representing you, a Large Language Model',
[
new StringSchema('name', 'your name'),
]
);

$provider->withResponseChain([
new ProviderResponse(
text: json_encode(['name' => 'Nyx']),
toolCalls: [],
usage: new Usage(10, 10),
finishReason: FinishReason::Stop,
response: ['id' => '123', 'model' => 'claude-3-5-sonnet-20240620']
),
]);

resolve(PrismManager::class)
->extend('test', function (Application $app, array $config) use ($provider): \Tests\TestDoubles\TestProvider {

expect($config)->toBe(['api_key' => '1234']);

return $provider;
});

(new Generator)
->using('test', 'latest')
->withPrompt('Who are you?')
->withSchema($schema)
->usingProviderConfig(['api_key' => '1234'])
->generate();
});
29 changes: 29 additions & 0 deletions tests/Generators/TextGeneratorTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
use EchoLabs\Prism\ValueObjects\Messages\UserMessage;
use EchoLabs\Prism\ValueObjects\ToolCall;
use EchoLabs\Prism\ValueObjects\Usage;
use Illuminate\Contracts\Foundation\Application;
use Tests\TestDoubles\TestProvider;

it('correctly resolves a provider', function (): void {
Expand Down Expand Up @@ -316,3 +317,31 @@
new UserMessage('Who are you?'),
]);
});

it('allows for custom provider configuration', function (): void {
$provider = new TestProvider;

$provider->withResponseChain([
new ProviderResponse(
text: 'I\'m Nyx!',
toolCalls: [],
usage: new Usage(10, 10),
finishReason: FinishReason::Stop,
response: ['id' => '123', 'model' => 'claude-3-5-sonnet-20240620']
),
]);

resolve(PrismManager::class)
->extend('test', function (Application $app, array $config) use ($provider): \Tests\TestDoubles\TestProvider {

expect($config)->toBe(['api_key' => '1234']);

return $provider;
});

(new Generator)
->using('test', 'latest')
->withPrompt('Who are you?')
->usingProviderConfig(['api_key' => '1234'])
->generate();
});
15 changes: 15 additions & 0 deletions tests/PrismManagerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

namespace Tests;

use EchoLabs\Prism\Contracts\Provider as ContractsProvider;
use EchoLabs\Prism\Enums\Provider;
use EchoLabs\Prism\PrismManager;
use EchoLabs\Prism\Providers\Anthropic\Anthropic;
use EchoLabs\Prism\Providers\Mistral\Mistral;
use EchoLabs\Prism\Providers\Ollama\Ollama;
use EchoLabs\Prism\Providers\OpenAI\OpenAI;
use EchoLabs\Prism\Providers\XAI\XAI;
use Illuminate\Contracts\Foundation\Application;
use Mockery;

it('can resolve Anthropic', function (): void {
$manager = new PrismManager($this->app);
Expand Down Expand Up @@ -46,3 +49,15 @@
expect($manager->resolve(Provider::XAI))->toBeInstanceOf(XAI::class);
expect($manager->resolve('xai'))->toBeInstanceOf(XAI::class);
});

it('allows for custom provider configuration', function (): void {
$manager = new PrismManager($this->app);

$manager->extend('test', function (Application $app, array $config) {
expect($config)->toBe(['api_key' => '1234']);

return Mockery::mock(ContractsProvider::class);
});

$manager->resolve('test', ['api_key' => '1234']);
});
20 changes: 20 additions & 0 deletions tests/Testing/PrismFakeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,23 @@
->withPrompt('What is the meaning of life?')
->generate();
});

it('asserts provider config', function (): void {
$fake = Prism::fake([
new ProviderResponse(
text: 'The meaning of life is 42',
toolCalls: [],
usage: new Usage(42, 42),
finishReason: FinishReason::Stop,
response: ['id' => 'cpl_1234', 'model' => 'claude-3-sonnet'],
),
]);

Prism::text()
->using('anthropic', 'claude-3-sonnet')
->withPrompt('What is the meaning of life?')
->usingProviderConfig(['api_key' => '1234'])
->generate();

$fake->assertProviderConfig(['api_key' => '1234']);
});
Loading