diff --git a/README.md b/README.md index 9cf8d80..c069630 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Features - 🤖 **Model Providers** - Get detailed model information with type-safe responses from various model providers (OpenAI, Ollama, etc.) -- 💾 **Simple Cache** - PSR-16 Simple Cache support for caching model information +- 💾 **Caching** - PSR-16 Simple Cache support for caching model information - 🔌 **Extensibility** - Easily add support for additional model providers ## Requirements diff --git a/src/Contracts/ModelInfoProvider.php b/src/Contracts/ModelInfoProvider.php index dfc8c24..0ab6d77 100644 --- a/src/Contracts/ModelInfoProvider.php +++ b/src/Contracts/ModelInfoProvider.php @@ -12,7 +12,7 @@ interface ModelInfoProvider /** * Get the available models for a given provider. * - * @return array + * @return array */ public function getModels(ModelProvider $modelProvider): array; diff --git a/src/Data/ModelInfo.php b/src/Data/ModelInfo.php index dc5fa71..e922bf7 100644 --- a/src/Data/ModelInfo.php +++ b/src/Data/ModelInfo.php @@ -9,12 +9,13 @@ use Cortex\ModelInfo\Enums\ModelProvider; /** - * @phpstan-type ModelInfoData array{name: string, provider: string|ModelProvider, type: string|ModelType, max_input_tokens: int|null, max_output_tokens: int|null, input_cost_per_token?: float, output_cost_per_token?: float, features?: array, is_deprecated?: bool} + * @phpstan-type ModelInfoData array{name: string, provider: string|ModelProvider, type: string|ModelType, max_input_tokens: int|null, max_output_tokens: int|null, input_cost_per_token?: float, output_cost_per_token?: float, features?: array, is_deprecated?: bool, metadata?: array} */ readonly class ModelInfo { /** * @param array<\Cortex\ModelInfo\Enums\ModelFeature> $features + * @param array $metadata */ public function __construct( public string $name, @@ -22,10 +23,11 @@ public function __construct( public ModelType $type, public ?int $maxInputTokens, public ?int $maxOutputTokens, - public float $inputCostPerToken, - public float $outputCostPerToken, + public ?float $inputCostPerToken, + public ?float $outputCostPerToken, public array $features, public bool $isDeprecated = false, + public array $metadata = [], ) {} public function supportsFeature(ModelFeature $modelFeature): bool @@ -33,6 +35,11 @@ public function supportsFeature(ModelFeature $modelFeature): bool return in_array($modelFeature, $this->features, true); } + public function getMetadata(string $key): mixed + { + return $this->metadata[$key] ?? null; + } + /** * @param ModelInfoData $data */ @@ -52,10 +59,11 @@ public static function createFromArray(array $data): self $type, $data['max_input_tokens'] ?? null, $data['max_output_tokens'] ?? null, - $data['input_cost_per_token'] ?? 0.0, - $data['output_cost_per_token'] ?? 0.0, + $data['input_cost_per_token'] ?? null, + $data['output_cost_per_token'] ?? null, $data['features'] ?? [], $data['is_deprecated'] ?? false, + $data['metadata'] ?? [], ); } } diff --git a/src/Enums/ModelProvider.php b/src/Enums/ModelProvider.php index 690116c..e6ca026 100644 --- a/src/Enums/ModelProvider.php +++ b/src/Enums/ModelProvider.php @@ -6,6 +6,7 @@ use Cortex\ModelInfo\Data\ModelInfo; use Cortex\ModelInfo\ModelInfoFactory; +use Psr\Container\ContainerExceptionInterface; use Cortex\ModelInfo\Providers\Concerns\DiscoversPsrImplementations; enum ModelProvider: string @@ -29,7 +30,7 @@ enum ModelProvider: string /** * @param array|null $modelInfoProviders * - * @return array + * @return array */ public function models(?array $modelInfoProviders = null): array { @@ -81,11 +82,13 @@ public static function modelInfoFactory(?array $modelInfoProviders = null): Mode { $container = self::discoverContainer(); - if ($container?->has(ModelInfoFactory::class) === true) { - // @phpstan-ignore return.type - return $container->get(ModelInfoFactory::class); + try { + /** @var \Cortex\ModelInfo\ModelInfoFactory $factory */ + $factory = $container?->get(ModelInfoFactory::class); + } catch (ContainerExceptionInterface) { + // } - return new ModelInfoFactory($modelInfoProviders); + return $factory ?? new ModelInfoFactory($modelInfoProviders); } } diff --git a/src/Enums/ModelType.php b/src/Enums/ModelType.php index 3a7daf5..f581836 100644 --- a/src/Enums/ModelType.php +++ b/src/Enums/ModelType.php @@ -13,5 +13,5 @@ enum ModelType: string case TextToSpeech = 'text_to_speech'; case SpeechToText = 'speech_to_text'; case Moderation = 'moderation'; - case Other = 'other'; + case Unknown = 'unknown'; } diff --git a/src/ModelInfoFactory.php b/src/ModelInfoFactory.php index f3655c6..98f37e5 100644 --- a/src/ModelInfoFactory.php +++ b/src/ModelInfoFactory.php @@ -11,6 +11,7 @@ use Cortex\ModelInfo\Exceptions\ModelInfoException; use Cortex\ModelInfo\Providers\OllamaModelInfoProvider; use Cortex\ModelInfo\Providers\LiteLLMModelInfoProvider; +use Cortex\ModelInfo\Providers\LMStudioModelInfoProvider; use Cortex\ModelInfo\Providers\Concerns\DiscoversPsrImplementations; class ModelInfoFactory @@ -36,7 +37,7 @@ public function __construct( } /** - * @return array + * @return array */ public function getModels(ModelProvider $modelProvider): array { @@ -59,7 +60,7 @@ public function getModelInfo(ModelProvider $modelProvider, string $model): ?Mode /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array + * @return array */ public function getModelsOrFail(ModelProvider $modelProvider): array { @@ -123,6 +124,7 @@ protected static function defaultModelInfoProviders(): array { return [ new OllamaModelInfoProvider(), + new LMStudioModelInfoProvider(), new LiteLLMModelInfoProvider(), ]; } diff --git a/src/Providers/CustomModelInfoProvider.php b/src/Providers/CustomModelInfoProvider.php index 21842fa..090b525 100644 --- a/src/Providers/CustomModelInfoProvider.php +++ b/src/Providers/CustomModelInfoProvider.php @@ -40,17 +40,9 @@ public function supportedModelProviders(): array return ModelProvider::cases(); } - /** - * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException - * - * @return array - */ public function getModels(ModelProvider $modelProvider): array { - return array_map( - fn(ModelInfo $model): string => $model->name, - $this->models, - ); + return $this->models; } public function getModelInfo(ModelProvider $modelProvider, string $model): ModelInfo diff --git a/src/Providers/LMStudioModelInfoProvider.php b/src/Providers/LMStudioModelInfoProvider.php index ae3a13a..fc48135 100644 --- a/src/Providers/LMStudioModelInfoProvider.php +++ b/src/Providers/LMStudioModelInfoProvider.php @@ -14,7 +14,7 @@ use Cortex\ModelInfo\Providers\Concerns\MakesRequests; /** - * @phpstan-type ModelInfoResponse array{id: string, object: string, type: string, max_context_length: int, type: ?string} + * @phpstan-type LMStudioModelInfoResponse array{id: string, object: string, type: string, max_context_length: int, type: ?string} */ class LMStudioModelInfoProvider implements ModelInfoProvider { @@ -38,7 +38,7 @@ public function supportedModelProviders(): array /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array + * @return array */ public function getModels(ModelProvider $modelProvider): array { @@ -46,26 +46,30 @@ public function getModels(ModelProvider $modelProvider): array $body = $this->getModelsResponse(); - $models = array_map( - // @phpstan-ignore return.type,argument.type - fn(array $model): string => $model['id'], + return array_values(array_map( + fn(array $model): ModelInfo => self::mapModelInfo($model), $body['data'], - ); - - return array_values($models); + )); } public function getModelInfo(ModelProvider $modelProvider, string $model): ModelInfo { $this->checkSupportOrFail($modelProvider); - $body = $this->getModelInfoResponse($model); - $type = $body['type'] ?? ''; + return self::mapModelInfo( + $this->getModelInfoResponse($model), + ); + } + /** + * @param LMStudioModelInfoResponse $body + */ + protected static function mapModelInfo(array $body): ModelInfo + { return new ModelInfo( - name: $model, + name: $body['id'], provider: ModelProvider::LMStudio, - type: self::getModelType($type), + type: self::getModelType($body['type'] ?? ''), maxInputTokens: self::getMaxInputTokens($body['max_context_length']), maxOutputTokens: null, inputCostPerToken: 0.0, @@ -75,7 +79,7 @@ public function getModelInfo(ModelProvider $modelProvider, string $model): Model } /** - * @param ModelInfoResponse $body + * @param LMStudioModelInfoResponse $body * * @return array */ @@ -96,7 +100,7 @@ protected static function getModelType(string $type): ModelType return match ($type) { 'llm' => ModelType::Chat, 'embeddings' => ModelType::Embedding, - default => ModelType::Other, + default => ModelType::Unknown, }; } @@ -108,7 +112,7 @@ protected static function getMaxInputTokens(int $maxContextLength): ?int /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return ModelInfoResponse + * @return LMStudioModelInfoResponse */ protected function getModelInfoResponse(string $model): array { @@ -127,7 +131,7 @@ protected function getModelInfoResponse(string $model): array /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array{data: array{id: string}} + * @return array{data: array} */ protected function getModelsResponse(): array { diff --git a/src/Providers/LiteLLMModelInfoProvider.php b/src/Providers/LiteLLMModelInfoProvider.php index 95a0f19..fc280f3 100644 --- a/src/Providers/LiteLLMModelInfoProvider.php +++ b/src/Providers/LiteLLMModelInfoProvider.php @@ -48,7 +48,7 @@ public function supportedModelProviders(): array /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array + * @return array */ public function getModels(ModelProvider $modelProvider): array { @@ -61,7 +61,11 @@ public function getModels(ModelProvider $modelProvider): array fn(array $model): bool => $model['litellm_provider'] === $modelProvider->value, ); - return array_keys($models); + return array_values(array_map( + fn(array $modelInfo, string $model): ModelInfo => self::mapModelInfo($modelProvider, $model, $modelInfo), + $models, + array_keys($models), + )); } /** @@ -90,6 +94,14 @@ public function getModelInfo(ModelProvider $modelProvider, string $model): Model throw new ModelInfoException('Model not found'); } + return self::mapModelInfo($modelProvider, $model, $modelInfo); + } + + protected static function mapModelInfo( + ModelProvider $modelProvider, + string $model, + array $modelInfo, + ): ModelInfo { return new ModelInfo( name: $model, provider: $modelProvider, @@ -162,7 +174,7 @@ protected static function mapModelType(string $type): ModelType 'audio_speech' => ModelType::TextToSpeech, 'audio_transcription' => ModelType::SpeechToText, 'moderation' => ModelType::Moderation, - default => ModelType::Other, + default => ModelType::Unknown, }; } diff --git a/src/Providers/OllamaModelInfoProvider.php b/src/Providers/OllamaModelInfoProvider.php index 03350b9..3ef6870 100644 --- a/src/Providers/OllamaModelInfoProvider.php +++ b/src/Providers/OllamaModelInfoProvider.php @@ -13,6 +13,10 @@ use Cortex\ModelInfo\Providers\Concerns\ChecksSupport; use Cortex\ModelInfo\Providers\Concerns\MakesRequests; +/** + * @phpstan-type OllamaModelsResponse array{name: string} + * @phpstan-type OllamaModelInfoResponse array{name: string, model_info: array|null, capabilities: array|null} + */ class OllamaModelInfoProvider implements ModelInfoProvider { use ChecksSupport; @@ -35,7 +39,7 @@ public function supportedModelProviders(): array /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array + * @return array */ public function getModels(ModelProvider $modelProvider): array { @@ -43,23 +47,34 @@ public function getModels(ModelProvider $modelProvider): array $body = $this->getModelsResponse(); - return array_map( - // @phpstan-ignore return.type,argument.type - fn(array $model): string => $model['name'], + return array_values(array_map( + fn(array $model): ModelInfo => self::mapModelInfo($model), $body['models'], - ); + )); } + /** + * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException + */ public function getModelInfo(ModelProvider $modelProvider, string $model): ModelInfo { $this->checkSupportOrFail($modelProvider); - $body = $this->getModelInfoResponse($model); - $modelInfo = $body['model_info'] ?? []; - $capabilities = $body['capabilities'] ?? []; + return self::mapModelInfo( + $this->getModelInfoResponse($model), + ); + } + + /** + * @param OllamaModelInfoResponse|OllamaModelsResponse $modelResponseBody + */ + protected static function mapModelInfo(array $modelResponseBody): ModelInfo + { + $modelInfo = $modelResponseBody['model_info'] ?? []; + $capabilities = $modelResponseBody['capabilities'] ?? []; return new ModelInfo( - name: $model, + name: $modelResponseBody['name'], provider: ModelProvider::Ollama, type: self::getModelType($capabilities), maxInputTokens: self::getMaxInputTokens($modelInfo), @@ -104,7 +119,7 @@ protected static function getModelType(array $capabilities): ModelType return match (true) { in_array('completion', $capabilities, true) => ModelType::Chat, in_array('embedding', $capabilities, true) => ModelType::Embedding, - default => ModelType::Other, + default => ModelType::Unknown, }; } @@ -135,7 +150,7 @@ protected static function getMaxInputTokens(array $modelInfo): ?int /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array{model_info: array|null, capabilities: array|null} + * @return OllamaModelInfoResponse */ protected function getModelInfoResponse(string $model): array { @@ -154,7 +169,7 @@ protected function getModelInfoResponse(string $model): array /** * @throws \Cortex\ModelInfo\Exceptions\ModelInfoException * - * @return array{models: array{name: string}} + * @return array{models: array} */ protected function getModelsResponse(): array { diff --git a/tests/Unit/ModelInfoFactoryTest.php b/tests/Unit/ModelInfoFactoryTest.php index a22bf8b..a46ab0c 100644 --- a/tests/Unit/ModelInfoFactoryTest.php +++ b/tests/Unit/ModelInfoFactoryTest.php @@ -39,6 +39,7 @@ ], ])), new Response(body: json_encode([ + 'name' => 'gemma3:12b', 'model_info' => [ 'mock.context_length' => 1024 * 8, ], @@ -71,11 +72,11 @@ $ollamaModels = $factory->getModels(ModelProvider::Ollama); $customModels = $factory->getModels(ModelProvider::Custom); - expect($ollamaModels)->toBeArray()->toHaveCount(1); - expect($ollamaModels[0])->toEqual('gemma3:12b'); + expect($ollamaModels)->toBeArray()->toHaveCount(1)->toContainOnlyInstancesOf(ModelInfo::class); + expect($ollamaModels[0]->name)->toEqual('gemma3:12b'); - expect($customModels)->toBeArray()->toHaveCount(1); - expect($customModels[0])->toEqual('foobar'); + expect($customModels)->toBeArray()->toHaveCount(1)->toContainOnlyInstancesOf(ModelInfo::class); + expect($customModels[0]->name)->toEqual('foobar'); $ollamaModelInfo = $factory->getModelInfo(ModelProvider::Ollama, 'gemma3:12b'); $customModelInfo = $factory->getModelInfo(ModelProvider::Custom, 'foobar'); diff --git a/tests/Unit/Providers/CustomModelInfoProviderTest.php b/tests/Unit/Providers/CustomModelInfoProviderTest.php index 1ae85cb..432dd54 100644 --- a/tests/Unit/Providers/CustomModelInfoProviderTest.php +++ b/tests/Unit/Providers/CustomModelInfoProviderTest.php @@ -52,7 +52,7 @@ $models = $provider->getModels(ModelProvider::Custom); - expect($models)->toBeArray()->toHaveCount(2)->toContain('custom-model', 'gpt-4o'); + expect($models)->toBeArray()->toContainOnlyInstancesOf(ModelInfo::class); }); test('it can get the model info', function (): void { diff --git a/tests/Unit/Providers/LMStudioModelInfoProviderTest.php b/tests/Unit/Providers/LMStudioModelInfoProviderTest.php index f2b949d..0680033 100644 --- a/tests/Unit/Providers/LMStudioModelInfoProviderTest.php +++ b/tests/Unit/Providers/LMStudioModelInfoProviderTest.php @@ -51,10 +51,12 @@ $models = $provider->getModels(ModelProvider::LMStudio); - expect($models)->toBeArray()->toBe([ - 'text-embedding-nomic-embed-text-v1.5', - 'qwen2.5-14b-instruct-mlx', - ]); + expect($models)->toBeArray() + ->toHaveCount(2) + ->toContainOnlyInstancesOf(ModelInfo::class); + + expect($models[0]->name)->toBe('text-embedding-nomic-embed-text-v1.5'); + expect($models[1]->name)->toBe('qwen2.5-14b-instruct-mlx'); }); test('it can get the model info', function (): void { diff --git a/tests/Unit/Providers/LiteLLMModelInfoProviderTest.php b/tests/Unit/Providers/LiteLLMModelInfoProviderTest.php index 94d434f..c01213d 100644 --- a/tests/Unit/Providers/LiteLLMModelInfoProviderTest.php +++ b/tests/Unit/Providers/LiteLLMModelInfoProviderTest.php @@ -19,9 +19,11 @@ new Response(body: json_encode([ 'gpt-4o' => [ 'litellm_provider' => 'openai', + 'mode' => 'chat', ], 'gpt-3.5-turbo' => [ 'litellm_provider' => 'openai', + 'mode' => 'chat', ], ])), ); @@ -32,7 +34,7 @@ $models = $provider->getModels(ModelProvider::OpenAI); - expect($models)->toBeArray()->toHaveCount(2)->toContain('gpt-4o', 'gpt-3.5-turbo'); + expect($models)->toBeArray()->toContainOnlyInstancesOf(ModelInfo::class); }); test('it can get the model info', function (): void { diff --git a/tests/Unit/Providers/OllamaModelInfoProviderTest.php b/tests/Unit/Providers/OllamaModelInfoProviderTest.php index cec231b..eb237fe 100644 --- a/tests/Unit/Providers/OllamaModelInfoProviderTest.php +++ b/tests/Unit/Providers/OllamaModelInfoProviderTest.php @@ -45,12 +45,17 @@ $models = $provider->getModels(ModelProvider::Ollama); - expect($models)->toBeArray()->toHaveCount(1)->toContain('gemma3:12b'); + expect($models)->toBeArray() + ->toHaveCount(1) + ->toContainOnlyInstancesOf(ModelInfo::class); + + expect($models[0]->name)->toBe('gemma3:12b'); }); test('it can get the model info', function (): void { $client = $this->mockHttpClient( new Response(body: json_encode([ + 'name' => 'mistral-small3.1', 'model_info' => [ 'mock.context_length' => 1024 * 8, ],