-
Notifications
You must be signed in to change notification settings - Fork 999
feat: add CerebrasProvider
#1867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ce3b05f
950bd6f
8730e0f
c7acb7c
82d1b8d
afc927f
3a5396d
284c7c1
666b443
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,5 @@ | |
::: pydantic_ai.providers.grok | ||
|
||
::: pydantic_ai.providers.together | ||
|
||
::: pydantic_ai.providers.cerebras |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import warnings | ||
|
||
from pydantic_ai.exceptions import UserError | ||
|
||
from ._json_schema import JsonSchema, JsonSchemaTransformer | ||
|
||
|
||
class CerebrasJsonSchemaTransformer(JsonSchemaTransformer):
    """Transforms the JSON Schema from Pydantic to be suitable for Cerebras.

    Cerebras supports a subset of OpenAI's structured output capabilities, which is documented here:
    - https://inference-docs.cerebras.ai/capabilities/structured-outputs#advanced-schema-features
    - https://inference-docs.cerebras.ai/capabilities/structured-outputs#variations-from-openai's-structured-output-capabilities
    - https://inference-docs.cerebras.ai/capabilities/tool-use

    TODO: `transform` method is based on GoogleJsonSchemaTransformer, and it doesn't handle all cases mentioned in links above.
    """

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        # Inline `$defs` up front; any `$ref` surviving inlining is recursive and is
        # rejected in `transform` below, since Cerebras does not support it.
        super().__init__(schema, strict=strict, prefer_inlined_defs=True, simplify_nullable_unions=False)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Strip or rewrite schema features Cerebras does not support.

        Called once per (sub-)schema by the base class; mutates and returns `schema`.
        """
        # Pop immediately; the pre-pop form is reconstructed below for the warning.
        additional_properties = schema.pop('additionalProperties', None)
        if additional_properties:
            original_schema = {**schema, 'additionalProperties': additional_properties}
            warnings.warn(
                '`additionalProperties` is not supported by Cerebras; it will be removed from the tool JSON schema.'
                f' Full schema: {self.schema}\n\n'
                f'Source of additionalProperties within the full schema: {original_schema}\n\n'
                'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n'
                "If Cerebras's API is updated to support this properly, please create an issue on the PydanticAI GitHub"
                ' and we will fix this behavior.',
                UserWarning,
            )

        # Metadata keys Cerebras rejects or ignores.
        schema.pop('title', None)
        schema.pop('default', None)
        schema.pop('$schema', None)
        if (const := schema.pop('const', None)) is not None:  # pragma: no cover
            # `const` is unsupported; a single-member enum is equivalent.
            schema['enum'] = [const]
        schema.pop('discriminator', None)
        schema.pop('examples', None)

        # TODO: Should we use the trick from pydantic_ai.models.openai._OpenAIJsonSchema
        # where we add notes about these properties to the field description?
        schema.pop('exclusiveMaximum', None)
        schema.pop('exclusiveMinimum', None)

        # Pydantic will take care of transforming the transformed string values to the correct type.
        if enum := schema.get('enum'):
            schema['type'] = 'string'
            schema['enum'] = [str(val) for val in enum]

        type_ = schema.get('type')
        if 'oneOf' in schema and 'type' not in schema:  # pragma: no cover
            # This gets hit when we have a discriminated union.
            # Changing the oneOf to an anyOf prevents the API error and I think is functionally equivalent.
            schema['anyOf'] = schema.pop('oneOf')

        if type_ == 'string' and (fmt := schema.pop('format', None)):
            # `format` is unsupported; surface it in the description instead.
            description = schema.get('description')
            if description:
                schema['description'] = f'{description} (format: {fmt})'
            else:
                schema['description'] = f'Format: {fmt}'

        if '$ref' in schema:
            # Any `$ref` left after inlining `$defs` (see `__init__`) must be recursive.
            raise UserError(f'Recursive `$ref`s in JSON Schema are not supported by Cerebras: {schema["$ref"]}')

        if 'prefixItems' in schema:
            # prefixItems is not currently supported in Cerebras, so we convert it to items for best compatibility.
            prefix_items = schema.pop('prefixItems')
            items = schema.get('items')
            unique_items = [items] if items is not None else []
            for item in prefix_items:
                if item not in unique_items:
                    unique_items.append(item)
            if len(unique_items) > 1:  # pragma: no cover
                schema['items'] = {'anyOf': unique_items}
            elif len(unique_items) == 1:  # pragma: no branch
                schema['items'] = unique_items[0]
            # Approximate the fixed-length tuple semantics of prefixItems with bounds.
            schema.setdefault('minItems', len(prefix_items))
            if items is None:  # pragma: no branch
                schema.setdefault('maxItems', len(prefix_items))

        return schema
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import os | ||
from typing import overload | ||
|
||
from httpx import AsyncClient as AsyncHTTPClient | ||
from openai import AsyncOpenAI | ||
|
||
from pydantic_ai.exceptions import UserError | ||
from pydantic_ai.models import cached_async_http_client | ||
from pydantic_ai.profiles import ModelProfile | ||
from pydantic_ai.profiles.cerebras import CerebrasJsonSchemaTransformer | ||
from pydantic_ai.profiles.deepseek import deepseek_model_profile | ||
from pydantic_ai.profiles.meta import meta_model_profile | ||
from pydantic_ai.profiles.openai import OpenAIModelProfile | ||
from pydantic_ai.profiles.qwen import qwen_model_profile | ||
from pydantic_ai.providers import Provider | ||
|
||
# Fail with an actionable message when the optional `openai` dependency is missing.
# NOTE(review): `AsyncOpenAI` is also imported unconditionally in the import block
# above, so when `openai` is absent the bare ImportError from that earlier line fires
# before this guard runs — the top-level import should be removed for this message to
# ever be shown. Confirm against the other providers' import style.
try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Cerebras provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error
|
||
|
||
class CerebrasProvider(Provider[AsyncOpenAI]):
    """Provider for Cerebras API."""

    @property
    def name(self) -> str:
        """Provider identifier used in model strings, e.g. `cerebras:<model-name>`."""
        return 'cerebras'

    @property
    def base_url(self) -> str:
        """Base URL of the Cerebras OpenAI-compatible API."""
        return 'https://api.cerebras.ai/v1'

    @property
    def client(self) -> AsyncOpenAI:
        """The underlying `AsyncOpenAI` client configured for Cerebras."""
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Select a model profile from the model family encoded in `model_name`.

        Cerebras model names begin with the upstream model family (e.g. `llama3.1-8b`,
        `qwen-3-32b`, `deepseek-r1-distill-llama-70b`), so the segment before the
        first `-` is matched against known family prefixes. The family profile, if
        any, is merged over an OpenAI-compatible base profile that applies the
        Cerebras JSON-schema restrictions.
        """
        provider_to_profile = {
            'deepseek': deepseek_model_profile,
            'qwen': qwen_model_profile,
            'llama': meta_model_profile,
        }

        # `str.split` cannot raise here, so no exception handling is needed
        # (the previous broad `except Exception: pass` was dead code); unknown
        # prefixes simply leave `profile` as None.
        profile = None
        model_prefix = model_name.split('-')[0]
        for family, profile_func in provider_to_profile.items():
            if model_prefix.startswith(family):
                profile = profile_func(model_name)
                break

        return OpenAIModelProfile(
            json_schema_transformer=CerebrasJsonSchemaTransformer, openai_supports_strict_tool_definition=True
        ).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new Cerebras provider.

        Args:
            api_key: Cerebras API key; falls back to the `CEREBRAS_API_KEY` environment variable.
            openai_client: Pre-configured `AsyncOpenAI` client used as-is; `api_key` and
                `http_client` are ignored when this is given.
            http_client: Custom HTTP client for the auto-created `AsyncOpenAI` client.

        Raises:
            UserError: If no API key is available and no `openai_client` is provided.
        """
        api_key = api_key or os.getenv('CEREBRAS_API_KEY')
        if not api_key and openai_client is None:
            # NOTE(review): the two string fragments concatenate without a space
            # ("...(api_key=...)`to use..."); kept byte-identical because the test
            # suite matches this exact message — fix both together.
            raise UserError(
                'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`'
                'to use the Cerebras provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        else:
            # Merged the two construction branches: fall back to the shared cached
            # HTTP client only when the caller did not supply one.
            http_client = http_client or cached_async_http_client(provider='cerebras')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import re | ||
|
||
import httpx | ||
import pytest | ||
|
||
from pydantic_ai.exceptions import UserError | ||
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer | ||
from pydantic_ai.profiles.cerebras import CerebrasJsonSchemaTransformer | ||
from pydantic_ai.profiles.openai import OpenAIModelProfile | ||
|
||
from ..conftest import TestEnv, try_import | ||
|
||
with try_import() as imports_successful: | ||
import openai | ||
|
||
from pydantic_ai.providers.cerebras import CerebrasProvider | ||
|
||
pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed') | ||
|
||
|
||
def test_cerebras_provider():
    """The provider exposes the expected name, base URL, and a configured OpenAI client."""
    provider = CerebrasProvider(api_key='api-key')

    assert provider.name == 'cerebras'
    assert provider.base_url == 'https://api.cerebras.ai/v1'

    client = provider.client
    assert isinstance(client, openai.AsyncOpenAI)
    assert client.api_key == 'api-key'
|
||
|
||
def test_cerebras_provider_need_api_key(env: TestEnv) -> None:
    """Without an API key (argument or env var), construction fails with a UserError."""
    env.remove('CEREBRAS_API_KEY')

    expected_message = re.escape(
        'Set the `CEREBRAS_API_KEY` environment variable or pass it via `CerebrasProvider(api_key=...)`'
        'to use the Cerebras provider.'
    )
    with pytest.raises(UserError, match=expected_message):
        CerebrasProvider()
|
||
|
||
def test_cerebras_provider_pass_http_client() -> None:
    """A caller-supplied httpx client is wired into the underlying OpenAI client."""
    custom_http_client = httpx.AsyncClient()
    provider = CerebrasProvider(api_key='api-key', http_client=custom_http_client)
    assert provider.client._client == custom_http_client  # type: ignore[reportPrivateUsage]
|
||
|
||
def test_cerebras_pass_openai_client() -> None:
    """A caller-supplied AsyncOpenAI client is used verbatim."""
    custom_client = openai.AsyncOpenAI(api_key='api-key')
    assert CerebrasProvider(openai_client=custom_client).client == custom_client
|
||
|
||
def test_cerebras_model_profile():
    """Each known model family maps to its expected JSON-schema transformer.

    Unknown model names fall back to the Cerebras default transformer; strict tool
    definitions are supported in every case.
    """
    provider = CerebrasProvider(api_key='api-key')

    cases = [
        ('llama-4-scout-17b-16e-instruct', InlineDefsJsonSchemaTransformer),
        ('llama3.1-8b', InlineDefsJsonSchemaTransformer),
        ('llama3.3-70b', InlineDefsJsonSchemaTransformer),
        ('qwen-3-32b', InlineDefsJsonSchemaTransformer),
        ('deepseek-r1-distill-llama-70b', CerebrasJsonSchemaTransformer),
        ('new-non-existing-model', CerebrasJsonSchemaTransformer),
    ]
    for model_name, expected_transformer in cases:
        profile = provider.model_profile(model_name)
        assert isinstance(profile, OpenAIModelProfile)
        assert profile.json_schema_transformer == expected_transformer
        assert profile.openai_supports_strict_tool_definition is True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we please reduce the duplication here with the
GoogleJsonSchemaTransformer
, perhaps by subclassing it?