diff --git a/docs/api/providers.md b/docs/api/providers.md index e8086025c..0ae859a19 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -25,3 +25,5 @@ ::: pydantic_ai.providers.grok ::: pydantic_ai.providers.together + +::: pydantic_ai.providers.cerebras diff --git a/docs/models/openai.md b/docs/models/openai.md index 0f18fafa0..9552ec46d 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -407,3 +407,21 @@ model = OpenAIModel( agent = Agent(model) ... ``` + +### Cerebras + +Go to [Cerebras](https://www.cerebras.ai/) and create an API key in your account settings. +Once you have the API key, you can use it with the `CerebrasProvider`: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.cerebras import CerebrasProvider + +model = OpenAIModel( + 'qwen-3-32b', # model library available at https://inference-docs.cerebras.ai/introduction + provider=CerebrasProvider(api_key='your-cerebras-api-key'), +) +agent = Agent(model) +... 
+``` diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 714080305..b996a966d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -543,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model: from .cohere import CohereModel return CohereModel(model_name, provider=provider) - elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'): + elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras'): from .openai import OpenAIModel return OpenAIModel(model_name, provider=provider) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index aeb21ce23..5953b5636 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -170,7 +170,7 @@ def __init__( self, model_name: OpenAIModelName, *, - provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'] + provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras'] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, system_prompt_role: OpenAISystemPromptRole | None = None, @@ -537,7 +537,7 @@ def __init__( self, model_name: OpenAIModelName, *, - provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'] + provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras'] | Provider[AsyncOpenAI] = 'openai', profile: ModelProfileSpec | None = None, ): diff --git a/pydantic_ai_slim/pydantic_ai/profiles/cerebras.py b/pydantic_ai_slim/pydantic_ai/profiles/cerebras.py new file mode 100644 index 000000000..a1058532a --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/profiles/cerebras.py @@ -0,0 +1,90 @@ +from 
__future__ import annotations as _annotations + +import warnings + +from pydantic_ai.exceptions import UserError + +from ._json_schema import JsonSchema, JsonSchemaTransformer + + +class CerebrasJsonSchemaTransformer(JsonSchemaTransformer): + """Transforms the JSON Schema from Pydantic to be suitable for Cerebras. + + Cerebras supports a subset of OpenAI's structured output capabilities, which is documented here: + - https://inference-docs.cerebras.ai/capabilities/structured-outputs#advanced-schema-features + - https://inference-docs.cerebras.ai/capabilities/structured-outputs#variations-from-openai's-structured-output-capabilities + - https://inference-docs.cerebras.ai/capabilities/tool-use + + TODO: `transform` method is based on GoogleJsonSchemaTransformer, and it doesn't handle all cases mentioned in links above. + """ + + def __init__(self, schema: JsonSchema, *, strict: bool | None = None): + super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=False) + + def transform(self, schema: JsonSchema) -> JsonSchema: + additional_properties = schema.pop( + 'additionalProperties', None + ) # popped here; re-added to `original_schema` below so the warning can show it + if additional_properties: + original_schema = {**schema, 'additionalProperties': additional_properties} + warnings.warn( + '`additionalProperties` is not supported by Cerebras; it will be removed from the tool JSON schema.'
+ f' Full schema: {self.schema}\n\n' + f'Source of additionalProperties within the full schema: {original_schema}\n\n' + 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n' + "If Cerebras's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub" + ' and we will fix this behavior.', + UserWarning, + ) + + schema.pop('title', None) + schema.pop('default', None) + schema.pop('$schema', None) + if (const := schema.pop('const', None)) is not None: # pragma: no cover + schema['enum'] = [const] + schema.pop('discriminator', None) + schema.pop('examples', None) + + # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema + # where we add notes about these properties to the field description? + schema.pop('exclusiveMaximum', None) + schema.pop('exclusiveMinimum', None) + + # Pydantic will take care of transforming the transformed string values to the correct type. + if enum := schema.get('enum'): + schema['type'] = 'string' + schema['enum'] = [str(val) for val in enum] + + type_ = schema.get('type') + if 'oneOf' in schema and 'type' not in schema: # pragma: no cover + # This gets hit when we have a discriminated union + # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent + schema['anyOf'] = schema.pop('oneOf') + + if type_ == 'string' and (fmt := schema.pop('format', None)): + description = schema.get('description') + if description: + schema['description'] = f'{description} (format: {fmt})' + else: + schema['description'] = f'Format: {fmt}' + + if '$ref' in schema: + raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Cerebras: {schema["$ref"]}') + + if 'prefixItems' in schema: + # prefixItems is not currently supported in Cerebras, so we convert it to items for best compatibility + prefix_items = schema.pop('prefixItems') + items = schema.get('items') + unique_items = [items] if items is not None else 
[] + for item in prefix_items: + if item not in unique_items: + unique_items.append(item) + if len(unique_items) > 1: # pragma: no cover + schema['items'] = {'anyOf': unique_items} + elif len(unique_items) == 1: # pragma: no branch + schema['items'] = unique_items[0] + schema.setdefault('minItems', len(prefix_items)) + if items is None: # pragma: no branch + schema.setdefault('maxItems', len(prefix_items)) + + return schema diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 86dd8ec74..f63289c7b 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -48,7 +48,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None: return None # pragma: no cover -def infer_provider(provider: str) -> Provider[Any]: +def infer_provider(provider: str) -> Provider[Any]: # noqa: C901 """Infer the provider from the provider name.""" if provider == 'openai': from .openai import OpenAIProvider @@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]: from .together import TogetherProvider return TogetherProvider() + elif provider == 'cerebras': + from .cerebras import CerebrasProvider + + return CerebrasProvider() else: # pragma: no cover raise ValueError(f'Unknown provider: {provider}') diff --git a/pydantic_ai_slim/pydantic_ai/providers/cerebras.py b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py new file mode 100644 index 000000000..7cf766969 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/cerebras.py @@ -0,0 +1,96 @@ +from __future__ import annotations as _annotations + +import os +from typing import overload + +from httpx import AsyncClient as AsyncHTTPClient +from openai import AsyncOpenAI + +from pydantic_ai.exceptions import UserError +from pydantic_ai.models import cached_async_http_client +from pydantic_ai.profiles import ModelProfile +from pydantic_ai.profiles.cerebras import 
CerebrasJsonSchemaTransformer +from pydantic_ai.profiles.deepseek import deepseek_model_profile +from pydantic_ai.profiles.meta import meta_model_profile +from pydantic_ai.profiles.openai import OpenAIModelProfile +from pydantic_ai.profiles.qwen import qwen_model_profile +from pydantic_ai.providers import Provider + +try: + from openai import AsyncOpenAI +except ImportError as _import_error: # pragma: no cover + raise ImportError( + 'Please install the `openai` package to use the Cerebras provider, ' + 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' + ) from _import_error + + +class CerebrasProvider(Provider[AsyncOpenAI]): + """Provider for Cerebras API.""" + + @property + def name(self) -> str: + return 'cerebras' + + @property + def base_url(self) -> str: + return 'https://api.cerebras.ai/v1' + + @property + def client(self) -> AsyncOpenAI: + return self._client + + def model_profile(self, model_name: str) -> ModelProfile | None: + provider_to_profile = { + 'deepseek': deepseek_model_profile, + 'qwen': qwen_model_profile, + 'llama': meta_model_profile, + } + profile = None + + try: + model_provider = model_name.split('-')[0] + for provider, profile_func in provider_to_profile.items(): + if model_provider.startswith(provider): + profile = profile_func(model_name) + break + except Exception as _: # pragma: no cover + pass + + return OpenAIModelProfile( + json_schema_transformer=CerebrasJsonSchemaTransformer, openai_supports_strict_tool_definition=True + ).update(profile) + + @overload + def __init__(self) -> None: ... + + @overload + def __init__(self, *, api_key: str) -> None: ... + + @overload + def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ... + + @overload + def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ... 
+ + def __init__( + self, + *, + api_key: str | None = None, + openai_client: AsyncOpenAI | None = None, + http_client: AsyncHTTPClient | None = None, + ) -> None: + api_key = api_key or os.getenv('CEREBRAS_API_KEY') + if not api_key and openai_client is None: + raise UserError( + 'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`' + 'to use the Cerebras provider.' + ) + + if openai_client is not None: + self._client = openai_client + elif http_client is not None: + self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) + else: + http_client = cached_async_http_client(provider='cerebras') + self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client) diff --git a/tests/providers/test_cerebras.py b/tests/providers/test_cerebras.py new file mode 100644 index 000000000..2af00e64e --- /dev/null +++ b/tests/providers/test_cerebras.py @@ -0,0 +1,84 @@ +import re + +import httpx +import pytest + +from pydantic_ai.exceptions import UserError +from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer +from pydantic_ai.profiles.cerebras import CerebrasJsonSchemaTransformer +from pydantic_ai.profiles.openai import OpenAIModelProfile + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + import openai + + from pydantic_ai.providers.cerebras import CerebrasProvider + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') + + +def test_cerebras_provider(): + provider = CerebrasProvider(api_key='api-key') + assert provider.name == 'cerebras' + assert provider.base_url == 'https://api.cerebras.ai/v1' + assert isinstance(provider.client, openai.AsyncOpenAI) + assert provider.client.api_key == 'api-key' + + +def test_cerebras_provider_need_api_key(env: TestEnv) -> None: + env.remove('CEREBRAS_API_KEY') + with pytest.raises( + UserError, + match=re.escape( + 'Set the 
`CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`' + 'to use the Cerebras provider.' + ), + ): + CerebrasProvider() + + +def test_cerebras_provider_pass_http_client() -> None: + http_client = httpx.AsyncClient() + provider = CerebrasProvider(http_client=http_client, api_key='api-key') + assert provider.client._client == http_client # type: ignore[reportPrivateUsage] + + +def test_cerebras_pass_openai_client() -> None: + openai_client = openai.AsyncOpenAI(api_key='api-key') + provider = CerebrasProvider(openai_client=openai_client) + assert provider.client == openai_client + + +def test_cerebras_model_profile(): + provider = CerebrasProvider(api_key='api-key') + + model = provider.model_profile('llama-4-scout-17b-16e-instruct') + assert isinstance(model, OpenAIModelProfile) + assert model.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True + + model = provider.model_profile('llama3.1-8b') + assert isinstance(model, OpenAIModelProfile) + assert model.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True + + model = provider.model_profile('llama3.3-70b') + assert isinstance(model, OpenAIModelProfile) + assert model.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True + + model = provider.model_profile('qwen-3-32b') + assert isinstance(model, OpenAIModelProfile) + assert model.json_schema_transformer == InlineDefsJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True + + model = provider.model_profile('deepseek-r1-distill-llama-70b') + assert isinstance(model, OpenAIModelProfile) + assert model.json_schema_transformer == CerebrasJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True + + model = provider.model_profile('new-non-existing-model') + assert isinstance(model, 
OpenAIModelProfile) + assert model.json_schema_transformer == CerebrasJsonSchemaTransformer + assert model.openai_supports_strict_tool_definition is True diff --git a/tests/providers/test_provider_names.py b/tests/providers/test_provider_names.py index 5b8362cea..555afb7e7 100644 --- a/tests/providers/test_provider_names.py +++ b/tests/providers/test_provider_names.py @@ -16,6 +16,7 @@ from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.azure import AzureProvider + from pydantic_ai.providers.cerebras import CerebrasProvider from pydantic_ai.providers.cohere import CohereProvider from pydantic_ai.providers.deepseek import DeepSeekProvider from pydantic_ai.providers.fireworks import FireworksProvider @@ -30,6 +31,7 @@ test_infer_provider_params = [ ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'), + ('cerebras', CerebrasProvider, 'CEREBRAS_API_KEY'), ('cohere', CohereProvider, 'CO_API_KEY'), ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'), ('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'),