diff --git a/docs/core-concepts/structured-output.md b/docs/core-concepts/structured-output.md index 12a649d..74d28ed 100644 --- a/docs/core-concepts/structured-output.md +++ b/docs/core-concepts/structured-output.md @@ -114,6 +114,9 @@ Structured output supports all the same options as text generation, including: - Message history - Tools and function calling - System prompts +- withClientOptions +- withClientRetry +- usingProviderConfig See the [Text Generation](./text-generation.md) documentation for details on these common settings. diff --git a/docs/core-concepts/testing.md b/docs/core-concepts/testing.md index b54d8b4..ff29c83 100644 --- a/docs/core-concepts/testing.md +++ b/docs/core-concepts/testing.md @@ -199,4 +199,7 @@ $fake->assertRequest(function ($requests) { expect($requests[0]->provider)->toBe('anthropic'); expect($requests[0]->model)->toBe('claude-3-sonnet'); }); + +// Assert provider configuration +$fake->assertProviderConfig(['api_key' => 'sk-1234']); ``` diff --git a/docs/core-concepts/text-generation.md b/docs/core-concepts/text-generation.md index 9ed5545..fb169f2 100644 --- a/docs/core-concepts/text-generation.md +++ b/docs/core-concepts/text-generation.md @@ -137,6 +137,10 @@ Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http-client#main-content). You can use this method to set [retries](https://laravel.com/docs/11.x/http-client#retries) e.g. `->withClientRetry(3, 100)`. +`usingProviderConfig` + +This allows for complete or partial override of the providers configuration. This is great for multi-tenant applications where users supply their own API keys. These values are merged with the original configuration allowing for partial or complete config override. + ## Response Handling The response object provides rich access to the generation results: diff --git a/src/Concerns/BuildsTextRequests.php b/src/Concerns/BuildsTextRequests.php index f30ae65..535b871 100644 --- a/src/Concerns/BuildsTextRequests.php +++ b/src/Concerns/BuildsTextRequests.php @@ -46,6 +46,9 @@ trait BuildsTextRequests protected string $model; + /** @var array */ + protected array $providerConfig = []; + /** @var array> */ protected $providerMeta = []; @@ -194,4 +197,14 @@ protected function textRequest(): Request providerMeta: $this->providerMeta, ); } + + /** + * @param array $config + */ + public function usingProviderConfig(array $config): self + { + $this->providerConfig = $config; + + return $this; + } } diff --git a/src/Prism.php b/src/Prism.php index 57f2ea3..f4482ab 100644 --- a/src/Prism.php +++ b/src/Prism.php @@ -27,8 +27,10 @@ public function __construct( private readonly PrismFake $fake ) {} - public function resolve(ProviderEnum|string $name): Provider + public function resolve(ProviderEnum|string $name, array $providerConfig = []): Provider { + $this->fake->setProviderConfig($providerConfig); + return $this->fake; } }); diff --git a/src/PrismManager.php b/src/PrismManager.php index 27a8e85..f4701c2 100644 --- a/src/PrismManager.php +++ b/src/PrismManager.php @@ -27,13 +27,15 @@ public function __construct( ) {} /** + * @param array $providerConfig + * * @throws InvalidArgumentException */ - public function resolve(ProviderEnum|string $name): Provider + public function resolve(ProviderEnum|string $name, array $providerConfig = []): Provider { $name = $this->resolveName($name); - $config = $this->getConfig($name) ?? []; + $config = array_merge($this->getConfig($name), $providerConfig); if (isset($this->customCreators[$name])) { return $this->callCustomCreator($name, $config); @@ -127,15 +129,11 @@ protected function callCustomCreator(string $provider, array $config): Provider } /** - * @return null|array + * @return array */ - protected function getConfig(string $name): ?array + protected function getConfig(string $name): array { - if ($name !== '' && $name !== '0') { - return config("prism.providers.{$name}"); - } - - return ['driver' => 'null']; + return config("prism.providers.{$name}", []); } /** diff --git a/src/Structured/Generator.php b/src/Structured/Generator.php index 9c7a52a..ed5b7ca 100644 --- a/src/Structured/Generator.php +++ b/src/Structured/Generator.php @@ -97,7 +97,7 @@ protected function decodeObject(string $responseText): ?array protected function sendProviderRequest(): ProviderResponse { $response = resolve(PrismManager::class) - ->resolve($this->provider) + ->resolve($this->provider, $this->providerConfig) ->structured($this->structuredRequest()); $responseMessage = new AssistantMessage( diff --git a/src/Testing/PrismFake.php b/src/Testing/PrismFake.php index a349da2..2bcb524 100644 --- a/src/Testing/PrismFake.php +++ b/src/Testing/PrismFake.php @@ -24,6 +24,9 @@ class PrismFake implements Provider /** @var array */ protected array $recorded = []; + /** @var array */ + protected $providerConfig = []; + /** * @param array $responses */ @@ -68,6 +71,14 @@ public function structured(StructuredRequest $request): ProviderResponse ); } + /** + * @param array $config + */ + public function setProviderConfig(array $config): void + { + $this->providerConfig = $config; + } + /** * @param Closure(array):void $fn */ @@ -89,6 +100,17 @@ public function assertPrompt(string $prompt): void ); } + /** + * @param array $providerConfig + */ + public function assertProviderConfig(array $providerConfig): void + { + PHPUnit::assertEqualsCanonicalizing( + $providerConfig, + $this->providerConfig + ); + } + /** * Assert number of calls made */ diff --git a/src/Text/Generator.php b/src/Text/Generator.php index b197dd7..3c5805d 100644 --- a/src/Text/Generator.php +++ b/src/Text/Generator.php @@ -50,7 +50,7 @@ public function generate(): Response protected function sendProviderRequest(): ProviderResponse { $response = resolve(PrismManager::class) - ->resolve($this->provider) + ->resolve($this->provider, $this->providerConfig) ->text($this->textRequest()); $responseMessage = new AssistantMessage( diff --git a/tests/Generators/StructuredGeneratorTest.php b/tests/Generators/StructuredGeneratorTest.php index 02aca05..767b8ac 100644 --- a/tests/Generators/StructuredGeneratorTest.php +++ b/tests/Generators/StructuredGeneratorTest.php @@ -20,6 +20,7 @@ use EchoLabs\Prism\ValueObjects\Messages\UserMessage; use EchoLabs\Prism\ValueObjects\ToolCall; use EchoLabs\Prism\ValueObjects\Usage; +use Illuminate\Contracts\Foundation\Application; use InvalidArgumentException; use Tests\TestDoubles\TestProvider; @@ -358,3 +359,40 @@ new UserMessage('Who are you?'), ]); }); + +it('allows for custom provider configuration', function (): void { + $provider = new TestProvider; + + $schema = new ObjectSchema( + 'model', + 'An object representing you, a Large Language Model', + [ + new StringSchema('name', 'your name'), + ] + ); + + $provider->withResponseChain([ + new ProviderResponse( + text: json_encode(['name' => 'Nyx']), + toolCalls: [], + usage: new Usage(10, 10), + finishReason: FinishReason::Stop, + response: ['id' => '123', 'model' => 'claude-3-5-sonnet-20240620'] + ), + ]); + + resolve(PrismManager::class) + ->extend('test', function (Application $app, array $config) use ($provider): \Tests\TestDoubles\TestProvider { + + expect($config)->toBe(['api_key' => '1234']); + + return $provider; + }); + + (new Generator) + ->using('test', 'latest') + ->withPrompt('Who are you?') + ->withSchema($schema) + ->usingProviderConfig(['api_key' => '1234']) + ->generate(); +}); diff --git a/tests/Generators/TextGeneratorTest.php b/tests/Generators/TextGeneratorTest.php index 54dc070..d61d8ff 100644 --- a/tests/Generators/TextGeneratorTest.php +++ b/tests/Generators/TextGeneratorTest.php @@ -16,6 +16,7 @@ use EchoLabs\Prism\ValueObjects\Messages\UserMessage; use EchoLabs\Prism\ValueObjects\ToolCall; use EchoLabs\Prism\ValueObjects\Usage; +use Illuminate\Contracts\Foundation\Application; use Tests\TestDoubles\TestProvider; it('correctly resolves a provider', function (): void { @@ -316,3 +317,31 @@ new UserMessage('Who are you?'), ]); }); + +it('allows for custom provider configuration', function (): void { + $provider = new TestProvider; + + $provider->withResponseChain([ + new ProviderResponse( + text: 'I\'m Nyx!', + toolCalls: [], + usage: new Usage(10, 10), + finishReason: FinishReason::Stop, + response: ['id' => '123', 'model' => 'claude-3-5-sonnet-20240620'] + ), + ]); + + resolve(PrismManager::class) + ->extend('test', function (Application $app, array $config) use ($provider): \Tests\TestDoubles\TestProvider { + + expect($config)->toBe(['api_key' => '1234']); + + return $provider; + }); + + (new Generator) + ->using('test', 'latest') + ->withPrompt('Who are you?') + ->usingProviderConfig(['api_key' => '1234']) + ->generate(); +}); diff --git a/tests/PrismManagerTest.php b/tests/PrismManagerTest.php index e6c9243..2058bb3 100644 --- a/tests/PrismManagerTest.php +++ b/tests/PrismManagerTest.php @@ -4,6 +4,7 @@ namespace Tests; +use EchoLabs\Prism\Contracts\Provider as ContractsProvider; use EchoLabs\Prism\Enums\Provider; use EchoLabs\Prism\PrismManager; use EchoLabs\Prism\Providers\Anthropic\Anthropic; @@ -11,6 +12,8 @@ use EchoLabs\Prism\Providers\Ollama\Ollama; use EchoLabs\Prism\Providers\OpenAI\OpenAI; use EchoLabs\Prism\Providers\XAI\XAI; +use Illuminate\Contracts\Foundation\Application; +use Mockery; it('can resolve Anthropic', function (): void { $manager = new PrismManager($this->app); @@ -46,3 +49,15 @@ expect($manager->resolve(Provider::XAI))->toBeInstanceOf(XAI::class); expect($manager->resolve('xai'))->toBeInstanceOf(XAI::class); }); + +it('allows for custom provider configuration', function (): void { + $manager = new PrismManager($this->app); + + $manager->extend('test', function (Application $app, array $config) { + expect($config)->toBe(['api_key' => '1234']); + + return Mockery::mock(ContractsProvider::class); + }); + + $manager->resolve('test', ['api_key' => '1234']); +}); diff --git a/tests/Testing/PrismFakeTest.php b/tests/Testing/PrismFakeTest.php index 88a2016..0ef8817 100644 --- a/tests/Testing/PrismFakeTest.php +++ b/tests/Testing/PrismFakeTest.php @@ -122,3 +122,23 @@ ->withPrompt('What is the meaning of life?') ->generate(); }); + +it('asserts provider config', function (): void { + $fake = Prism::fake([ + new ProviderResponse( + text: 'The meaning of life is 42', + toolCalls: [], + usage: new Usage(42, 42), + finishReason: FinishReason::Stop, + response: ['id' => 'cpl_1234', 'model' => 'claude-3-sonnet'], + ), + ]); + + Prism::text() + ->using('anthropic', 'claude-3-sonnet') + ->withPrompt('What is the meaning of life?') + ->usingProviderConfig(['api_key' => '1234']) + ->generate(); + + $fake->assertProviderConfig(['api_key' => '1234']); +});