Skip to content

feat: add CerebrasProvider #1867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
::: pydantic_ai.providers.grok

::: pydantic_ai.providers.together

::: pydantic_ai.providers.cerebras
18 changes: 18 additions & 0 deletions docs/models/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,21 @@ model = OpenAIModel(
agent = Agent(model)
...
```

### Cerebras

Go to [Cerebras](https://www.cerebras.ai/) and create an API key in your account settings.
Once you have the API key, you can use it with the `CerebrasProvider`:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.cerebras import CerebrasProvider

model = OpenAIModel(
'qwen-3-32b', # model library available at https://inference-docs.cerebras.ai/introduction
provider=CerebrasProvider(api_key='your-cerebras-api-key'),
)
agent = Agent(model)
...
```
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
from .cohere import CohereModel

return CohereModel(model_name, provider=provider)
elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras'):
from .openai import OpenAIModel

return OpenAIModel(model_name, provider=provider)
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
self,
model_name: OpenAIModelName,
*,
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras']
| Provider[AsyncOpenAI] = 'openai',
profile: ModelProfileSpec | None = None,
system_prompt_role: OpenAISystemPromptRole | None = None,
Expand Down Expand Up @@ -537,7 +537,7 @@ def __init__(
self,
model_name: OpenAIModelName,
*,
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'cerebras']
| Provider[AsyncOpenAI] = 'openai',
profile: ModelProfileSpec | None = None,
):
Expand Down
90 changes: 90 additions & 0 deletions pydantic_ai_slim/pydantic_ai/profiles/cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations as _annotations

import warnings

from pydantic_ai.exceptions import UserError

from ._json_schema import JsonSchema, JsonSchemaTransformer


class CerebrasJsonSchemaTransformer(JsonSchemaTransformer):
    """Transforms the JSON Schema from Pydantic to be suitable for Cerebras.

    Cerebras supports a subset of OpenAI's structured output capabilities, which is documented here:
    - https://inference-docs.cerebras.ai/capabilities/structured-outputs#advanced-schema-features
    - https://inference-docs.cerebras.ai/capabilities/structured-outputs#variations-from-openai's-structured-output-capabilities
    - https://inference-docs.cerebras.ai/capabilities/tool-use

    TODO: `transform` method is based on GoogleJsonSchemaTransformer, and it doesn't handle all cases mentioned in links above.
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Defs are inlined up front so `$ref`s don't reach the API; nullable unions pass through unchanged.
        super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=False)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Transform a single schema node in place and return it.

        Strips keywords Cerebras does not accept and rewrites unsupported constructs
        (`const`, untyped `oneOf`, string `format`, `prefixItems`) into supported equivalents.

        Raises:
            UserError: if a `$ref` survives def-inlining, i.e. the schema is recursive.
        """
        # `additionalProperties` is unsupported: pop it, but rebuild the original node so the
        # warning can show exactly where in the full schema it came from.
        additional_properties = schema.pop('additionalProperties', None)
        if additional_properties:
            original_schema = {**schema, 'additionalProperties': additional_properties}
            warnings.warn(
                '`additionalProperties` is not supported by Cerebras; it will be removed from the tool JSON schema.'
                f' Full schema: {self.schema}\n\n'
                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
                "If Cerebras' API is updated to support this properly, please create an issue on the PydanticAI GitHub"
                ' and we will fix this behavior.',
                UserWarning,
            )

        # Keywords Cerebras rejects or ignores; dropping them is safe for validation purposes.
        schema.pop('title', None)
        schema.pop('default', None)
        schema.pop('$schema', None)
        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
            # Cerebras has no `const`; a single-member enum is equivalent.
            schema['enum'] = [const]
        schema.pop('discriminator', None)
        schema.pop('examples', None)

        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
        # where we add notes about these properties to the field description?
        schema.pop('exclusiveMaximum', None)
        schema.pop('exclusiveMinimum', None)

        # Pydantic will take care of transforming the transformed string values to the correct type.
        if enum := schema.get('enum'):
            schema['type'] = 'string'
            schema['enum'] = [str(val) for val in enum]

        type_ = schema.get('type')
        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
            # This gets hit when we have a discriminated union.
            # Changing the oneOf to an anyOf prevents the API error and is functionally equivalent here.
            schema['anyOf'] = schema.pop('oneOf')

        if type_ == 'string' and (fmt := schema.pop('format', None)):
            # `format` is unsupported; fold it into the description so the model still sees the hint.
            description = schema.get('description')
            if description:
                schema['description'] = f'{description} (format: {fmt})'
            else:
                schema['description'] = f'Format: {fmt}'

        if '$ref' in schema:
            # prefer_inlined_defs=True means any `$ref` left at this point must be recursive.
            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Cerebras: {schema["$ref"]}')

        if 'prefixItems' in schema:
            # prefixItems is not currently supported in Cerebras, so we convert it to items for best compatibility.
            prefix_items = schema.pop('prefixItems')
            items = schema.get('items')
            unique_items = [items] if items is not None else []
            for item in prefix_items:
                if item not in unique_items:
                    unique_items.append(item)
            if len(unique_items) > 1:  # pragma: no cover
                schema['items'] = {'anyOf': unique_items}
            elif len(unique_items) == 1:  # pragma: no branch
                schema['items'] = unique_items[0]
            schema.setdefault('minItems', len(prefix_items))
            if items is None:  # pragma: no branch
                schema.setdefault('maxItems', len(prefix_items))

        return schema
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
return None # pragma: no cover


def infer_provider(provider: str) -> Provider[Any]:
def infer_provider(provider: str) -> Provider[Any]: # noqa: C901
"""Infer the provider from the provider name."""
if provider == 'openai':
from .openai import OpenAIProvider
Expand Down Expand Up @@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]:
from .together import TogetherProvider

return TogetherProvider()
elif provider == 'cerebras':
from .cerebras import CerebrasProvider

return CerebrasProvider()
else: # pragma: no cover
raise ValueError(f'Unknown provider: {provider}')
96 changes: 96 additions & 0 deletions pydantic_ai_slim/pydantic_ai/providers/cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations as _annotations

import os
from typing import overload

from httpx import AsyncClient as AsyncHTTPClient
from openai import AsyncOpenAI

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.cerebras import CerebrasJsonSchemaTransformer
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.openai import OpenAIModelProfile
from pydantic_ai.profiles.qwen import qwen_model_profile
from pydantic_ai.providers import Provider

try:
from openai import AsyncOpenAI
except ImportError as _import_error: # pragma: no cover
raise ImportError(
'Please install the `openai` package to use the Cerebras provider, '
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
) from _import_error


class CerebrasProvider(Provider[AsyncOpenAI]):
    """Provider for Cerebras API."""

    @property
    def name(self) -> str:
        """Provider identifier used in `provider:model` strings."""
        return 'cerebras'

    @property
    def base_url(self) -> str:
        """Cerebras' OpenAI-compatible inference endpoint."""
        return 'https://api.cerebras.ai/v1'

    @property
    def client(self) -> AsyncOpenAI:
        """The underlying `AsyncOpenAI` client configured for Cerebras."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Pick a vendor model profile from the model-name prefix, then force Cerebras specifics.

        The Cerebras JSON schema transformer and strict-tool support are applied unconditionally
        on top of whatever vendor profile (DeepSeek / Qwen / Meta) the name prefix matched.
        """
        prefix_to_profile = {
            'deepseek': deepseek_model_profile,
            'qwen': qwen_model_profile,
            'llama': meta_model_profile,
        }

        vendor_profile: ModelProfile | None = None
        try:
            leading_segment = model_name.split('-')[0]
            vendor_profile = next(
                (
                    profile_func(model_name)
                    for prefix, profile_func in prefix_to_profile.items()
                    if leading_segment.startswith(prefix)
                ),
                None,
            )
        except Exception:  # pragma: no cover
            pass

        return OpenAIModelProfile(
            json_schema_transformer=CerebrasJsonSchemaTransformer, openai_supports_strict_tool_definition=True
        ).update(vendor_profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create the provider from an API key, a pre-built client, or the environment.

        Raises:
            UserError: if no API key is available and no pre-built client was supplied.
        """
        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
        if not api_key and openai_client is None:
            # NOTE(review): the two concatenated literals below join without a space
            # ("...(api_key=...)`to use..."); kept byte-identical because
            # tests/providers/test_cerebras.py matches this exact string — fix both together.
            raise UserError(
                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`'
                'to use the Cerebras provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        else:
            self._client = AsyncOpenAI(
                base_url=self.base_url,
                api_key=api_key,
                http_client=http_client if http_client is not None else cached_async_http_client(provider='cerebras'),
            )
84 changes: 84 additions & 0 deletions tests/providers/test_cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import re

import httpx
import pytest

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.cerebras import CerebrasJsonSchemaTransformer
from pydantic_ai.profiles.openai import OpenAIModelProfile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
import openai

from pydantic_ai.providers.cerebras import CerebrasProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')


def test_cerebras_provider():
    """Basic construction: name, base URL, and a wrapped AsyncOpenAI client carrying the key."""
    provider = CerebrasProvider(api_key='api-key')

    assert provider.name == 'cerebras'
    assert provider.base_url == 'https://api.cerebras.ai/v1'

    wrapped = provider.client
    assert isinstance(wrapped, openai.AsyncOpenAI)
    assert wrapped.api_key == 'api-key'


def test_cerebras_provider_need_api_key(env: TestEnv) -> None:
    """Without CEREBRAS_API_KEY or an explicit api_key, construction fails with UserError."""
    env.remove('CEREBRAS_API_KEY')

    # Must stay byte-identical to the message raised by CerebrasProvider.__init__
    # (including the missing space between the concatenated literals).
    expected_message = (
        'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`'
        'to use the Cerebras provider.'
    )
    with pytest.raises(UserError, match=re.escape(expected_message)):
        CerebrasProvider()


def test_cerebras_provider_pass_http_client() -> None:
    """A caller-supplied httpx client is handed through to the underlying AsyncOpenAI client."""
    custom_http_client = httpx.AsyncClient()
    provider = CerebrasProvider(api_key='api-key', http_client=custom_http_client)
    assert provider.client._client == custom_http_client  # type: ignore[reportPrivateUsage]


def test_cerebras_pass_openai_client() -> None:
    """A pre-built AsyncOpenAI client is used as-is (no API key required)."""
    prebuilt_client = openai.AsyncOpenAI(api_key='api-key')
    assert CerebrasProvider(openai_client=prebuilt_client).client == prebuilt_client


def test_cerebras_model_profile():
    """Model-name prefixes select the right json_schema_transformer; unknown names fall back.

    Table-driven to replace six copy-pasted stanzas; the trailing assert message pins
    which model name failed.
    """
    provider = CerebrasProvider(api_key='api-key')

    cases = [
        # llama-prefixed models pick up Meta's profile, which inlines $defs.
        ('llama-4-scout-17b-16e-instruct', InlineDefsJsonSchemaTransformer),
        ('llama3.1-8b', InlineDefsJsonSchemaTransformer),
        ('llama3.3-70b', InlineDefsJsonSchemaTransformer),
        # qwen-prefixed models also inline $defs via the Qwen profile.
        ('qwen-3-32b', InlineDefsJsonSchemaTransformer),
        # deepseek-prefixed and unknown models keep the Cerebras transformer.
        ('deepseek-r1-distill-llama-70b', CerebrasJsonSchemaTransformer),
        ('new-non-existing-model', CerebrasJsonSchemaTransformer),
    ]

    for model_name, expected_transformer in cases:
        profile = provider.model_profile(model_name)
        assert isinstance(profile, OpenAIModelProfile), model_name
        assert profile.json_schema_transformer == expected_transformer, model_name
        assert profile.openai_supports_strict_tool_definition is True, model_name
2 changes: 2 additions & 0 deletions tests/providers/test_provider_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from pydantic_ai.providers.anthropic import AnthropicProvider
from pydantic_ai.providers.azure import AzureProvider
from pydantic_ai.providers.cerebras import CerebrasProvider
from pydantic_ai.providers.cohere import CohereProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.fireworks import FireworksProvider
Expand All @@ -30,6 +31,7 @@

test_infer_provider_params = [
('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'),
('cerebras', CerebrasProvider, 'CEREBRAS_API_KEY'),
('cohere', CohereProvider, 'CO_API_KEY'),
('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'),
Expand Down