diff --git a/docs/api/models/huggingface.md b/docs/api/models/huggingface.md
new file mode 100644
index 000000000..72e78c4a3
--- /dev/null
+++ b/docs/api/models/huggingface.md
@@ -0,0 +1,7 @@
+# `pydantic_ai.models.huggingface`
+
+## Setup
+
+For details on how to set up authentication with this model, see [model configuration for Hugging Face](../../models/huggingface.md).
+
+::: pydantic_ai.models.huggingface
diff --git a/docs/api/providers.md b/docs/api/providers.md
index 926cf8e8b..8e808185a 100644
--- a/docs/api/providers.md
+++ b/docs/api/providers.md
@@ -29,3 +29,5 @@
 ::: pydantic_ai.providers.heroku.HerokuProvider
 
 ::: pydantic_ai.providers.openrouter.OpenRouterProvider
+
+::: pydantic_ai.providers.huggingface.HuggingFaceProvider
diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md
new file mode 100644
index 000000000..e99a77f00
--- /dev/null
+++ b/docs/models/huggingface.md
@@ -0,0 +1,91 @@
+# Hugging Face
+
+## Install
+
+To use `HuggingFaceModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group:
+
+```bash
+pip/uv-add "pydantic-ai-slim[huggingface]"
+```
+
+## Configuration
+
+To use [Hugging Face](https://huggingface.co/) Inference Providers, see the
+[Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for details on providers and pricing.
+You can generate a Hugging Face access token at https://huggingface.co/settings/tokens.
+
+## Hugging Face access token
+
+Once you have a Hugging Face access token, you can set it as an environment variable:
+
+```bash
+export HF_TOKEN='hf_token'
+```
+
+You can then use [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] by name:
+
+```python
+from pydantic_ai import Agent
+
+agent = Agent('huggingface:Qwen/Qwen3-235B-A22B')
+...
+```
+
+Or initialise the model directly with just the model name:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.huggingface import HuggingFaceModel
+
+model = HuggingFaceModel('Qwen/Qwen3-235B-A22B')
+agent = Agent(model)
+...
+```
+
+By default, the [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] uses the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider], which automatically selects the first
+inference provider (Cerebras, Together AI, Cohere, etc.) available for the model, sorted by your
+preferred order in https://hf.co/settings/inference-providers.
+
+## Configure the provider
+
+If you want to pass parameters in code to the provider, you can programmatically instantiate the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.huggingface import HuggingFaceModel
+from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius'))
+agent = Agent(model)
+...
+```
+
+## Custom Hugging Face client
+
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom
+[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise
+`headers`, `bill_to` (billing to an HF organization you're a member of), `base_url`, etc., as defined in the
+[Hugging Face Hub Python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
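+
+For instance, if you want requests routed through a specific provider URL rather than relying on automatic
+provider selection, a minimal sketch could look like this (the `base_url` below is illustrative):
+
+```python
+from huggingface_hub import AsyncInferenceClient
+
+from pydantic_ai import Agent
+from pydantic_ai.models.huggingface import HuggingFaceModel
+from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+client = AsyncInferenceClient(
+    api_key='hf_token',
+    base_url='https://router.huggingface.co/nebius/v1',  # illustrative router URL
+)
+
+model = HuggingFaceModel(
+    'Qwen/Qwen2.5-72B-Instruct',
+    provider=HuggingFaceProvider(hf_client=client),
+)
+agent = Agent(model)
+...
+```
+
+More commonly, you will pass a client configured with custom headers, billing, or a preferred provider, for example: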
+ +```python +from huggingface_hub import AsyncInferenceClient + +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider + +client = AsyncInferenceClient( + bill_to='openai', + api_key='hf_token', + provider='fireworks-ai', +) + +model = HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', + provider=HuggingFaceProvider(hf_client=client), +) +agent = Agent(model) +... +``` diff --git a/mkdocs.yml b/mkdocs.yml index d750c29bb..55fd86384 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - api/models/gemini.md - api/models/google.md - api/models/groq.md + - api/models/huggingface.md - api/models/instrumented.md - api/models/mistral.md - api/models/test.md diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 71fa7a188..87c4db7d1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -227,6 +227,14 @@ 'heroku:claude-3-7-sonnet', 'heroku:claude-4-sonnet', 'heroku:claude-3-haiku', + 'huggingface:Qwen/QwQ-32B', + 'huggingface:Qwen/Qwen2.5-72B-Instruct', + 'huggingface:Qwen/Qwen3-235B-A22B', + 'huggingface:Qwen/Qwen3-32B', + 'huggingface:deepseek-ai/DeepSeek-R1', + 'huggingface:meta-llama/Llama-3.3-70B-Instruct', + 'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct', 'mistral:codestral-latest', 'mistral:mistral-large-latest', 'mistral:mistral-moderation-latest', @@ -539,7 +547,7 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: +def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 """Infer the model from the name.""" if isinstance(model, Model): return model @@ -593,6 +601,10 @@ def infer_model(model: Model | KnownModelName | str) -> Model: from .bedrock import BedrockConverseModel return BedrockConverseModel(model_name, provider=provider) + elif provider == 'huggingface': + from .huggingface import HuggingFaceModel + + return HuggingFaceModel(model_name, provider=provider) else: raise UserError(f'Unknown model: {model}') # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py new file mode 100644 index 000000000..cf17e104e --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -0,0 +1,461 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal, Union, cast, overload + +from typing_extensions import assert_never + +from pydantic_ai.providers import Provider, infer_provider + +from .. 
import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from ..messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, + RetryPromptPart, + SystemPromptPart, + TextPart, + ThinkingPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + VideoUrl, +) +from ..settings import ModelSettings +from ..tools import ToolDefinition +from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests + +try: + import aiohttp + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputTool, + ChatCompletionInputToolCall, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputMessage, + ChatCompletionStreamOutput, + InferenceTimeoutError, + ) + from huggingface_hub.errors import HfHubHTTPError + +except ImportError as _import_error: + raise ImportError( + 'Please install `huggingface_hub` to use Hugging Face Inference Providers, ' + 'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`' + ) from _import_error + +__all__ = ( + 'HuggingFaceModel', + 'HuggingFaceModelSettings', +) + + +HFSystemPromptRole = Literal['system', 'user'] + +LatestHuggingFaceModelNames = Literal[ + 'deepseek-ai/DeepSeek-R1', + 'meta-llama/Llama-3.3-70B-Instruct', + 'meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'meta-llama/Llama-4-Scout-17B-16E-Instruct', + 'Qwen/QwQ-32B', + 'Qwen/Qwen2.5-72B-Instruct', + 'Qwen/Qwen3-235B-A22B', + 'Qwen/Qwen3-32B', +] +"""Latest Hugging Face models.""" + + +HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames] +"""Possible Hugging Face model names. + +You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). +""" + + +class HuggingFaceModelSettings(ModelSettings, total=False): + """Settings used for a Hugging Face model request. + + ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. + """ + + # This class is a placeholder for any future huggingface-specific settings + + +@dataclass(init=False) +class HuggingFaceModel(Model): + """A model that uses Hugging Face Inference Providers. + + Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API. + + Apart from `__init__`, all methods are private or match those of the base class. + """ + + client: AsyncInferenceClient = field(repr=False) + + _model_name: str = field(repr=False) + _system: str = field(default='huggingface', repr=False) + + def __init__( + self, + model_name: str, + *, + provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface', + ): + """Initialize a Hugging Face model. + + Args: + model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). + provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an + instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used. 
+ """ + self._model_name = model_name + self._provider = provider + if isinstance(provider, str): + provider = infer_provider(provider) + self.client = provider.client + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + check_allow_model_requests() + response = await self._completions_create( + messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + model_response = self._process_response(response) + model_response.usage.requests = 1 + return model_response + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[StreamedResponse]: + check_allow_model_requests() + response = await self._completions_create( + messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + yield await self._process_streamed_response(response) + + @property + def model_name(self) -> HuggingFaceModelName: + """The model name.""" + return self._model_name + + @property + def system(self) -> str: + """The system / model provider.""" + return self._system + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterable[ChatCompletionStreamOutput]: ... + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput: ... 
+ + async def _completions_create( + self, + messages: list[ModelMessage], + stream: bool, + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: + tools = self._get_tools(model_request_parameters) + + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif not model_request_parameters.allow_text_output: + tool_choice = 'required' + else: + tool_choice = 'auto' + + hf_messages = await self._map_messages(messages) + + try: + return await self.client.chat.completions.create( # type: ignore + model=self._model_name, + messages=hf_messages, # type: ignore + tools=tools, + tool_choice=tool_choice or None, + stream=stream, + stop=model_settings.get('stop_sequences', None), + temperature=model_settings.get('temperature', None), + top_p=model_settings.get('top_p', None), + seed=model_settings.get('seed', None), + presence_penalty=model_settings.get('presence_penalty', None), + frequency_penalty=model_settings.get('frequency_penalty', None), + logit_bias=model_settings.get('logit_bias', None), # type: ignore + logprobs=model_settings.get('logprobs', None), + top_logprobs=model_settings.get('top_logprobs', None), + extra_body=model_settings.get('extra_body'), # type: ignore + ) + except (InferenceTimeoutError, aiohttp.ClientResponseError, HfHubHTTPError) as e: + if isinstance(e, aiohttp.ClientResponseError): + raise ModelHTTPError( + status_code=e.status, + model_name=self.model_name, + body=e.response_error_payload, # type: ignore + ) from e + elif isinstance(e, HfHubHTTPError): + raise ModelHTTPError( + status_code=e.response.status_code, + model_name=self.model_name, + body=e.response.content, + ) from e + raise # pragma: lax no cover + + def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: + """Process a non-streamed response, and prepare a message to return.""" + if response.created: + timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) + else: + timestamp = _now_utc() + + choice = response.choices[0] + items: list[ModelResponsePart] = [] + + if choice.message.content is not None: + items.append(TextPart(choice.message.content)) + if choice.message.tool_calls is not None: + for c in choice.message.tool_calls: + items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) + return ModelResponse( + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + vendor_id=response.id, + ) + + async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior( # pragma: no cover + 'Streamed response ended without content or tool calls' + ) + + return HuggingFaceStreamedResponse( + _model_name=self._model_name, + _response=peekable_response, + _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc), + ) + + def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + if model_request_parameters.output_tools: + tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] + return tools 
+ + async def _map_messages( + self, messages: list[ModelMessage] + ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`.""" + hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = [] + for message in messages: + if isinstance(message, ModelRequest): + async for item in self._map_user_message(message): + hf_messages.append(item) + elif isinstance(message, ModelResponse): + texts: list[str] = [] + tool_calls: list[ChatCompletionInputToolCall] = [] + for item in message.parts: + if isinstance(item, ThinkingPart): + continue + if isinstance(item, TextPart): + texts.append(item.content) + elif isinstance(item, ToolCallPart): + tool_calls.append(self._map_tool_call(item)) + else: + assert_never(item) + message_param = ChatCompletionInputMessage(role='assistant') # type: ignore + if texts: + # Note: model responses from this model should only have one text item, so the following + # shouldn't merge multiple texts into one unless you switch models between runs: + message_param['content'] = '\n\n'.join(texts) + if tool_calls: + message_param['tool_calls'] = tool_calls + hf_messages.append(message_param) + else: + assert_never(message) + if instructions := self._get_instructions(messages): + hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore + return hf_messages + + @staticmethod + def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall: + return ChatCompletionInputToolCall.parse_obj_as_instance( # type: ignore + { + 'id': _guard_tool_call_id(t=t), + 'type': 'function', + 'function': { + 'name': t.tool_name, + 'arguments': t.args_as_json_str(), + }, + } + ) + + @staticmethod + def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: + tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore + { + 'type': 'function', + 'function': { + 'name': f.name, + 'description': f.description, + 'parameters': f.parameters_json_schema, + }, + } + ) + if f.strict: + tool_param['function']['strict'] = f.strict + return tool_param + + async def _map_user_message( + self, message: ModelRequest + ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + for part in message.parts: + if isinstance(part, SystemPromptPart): + yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore + elif isinstance(part, UserPromptPart): + yield await self._map_user_prompt(part) + elif isinstance(part, ToolReturnPart): + yield ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response_str(), + } + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + {'role': 'user', 'content': part.model_response()} + ) + else: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response(), + } + ) + else: + assert_never(part) + + @staticmethod + async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: + content: str | list[ChatCompletionInputMessage] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + 
content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, ImageUrl): + url = ChatCompletionInputURL(url=item.url) # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}') # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + else: # pragma: no cover + raise RuntimeError(f'Unsupported binary content type: {item.media_type}') + elif isinstance(item, AudioUrl): + raise NotImplementedError('AudioUrl is not supported for Hugging Face') + elif isinstance(item, DocumentUrl): + raise NotImplementedError('DocumentUrl is not supported for Hugging Face') + elif isinstance(item, VideoUrl): # pragma: no cover + raise NotImplementedError('VideoUrl is not supported for Hugging Face') + else: + assert_never(item) + return ChatCompletionInputMessage(role='user', content=content) # type: ignore + + +@dataclass +class HuggingFaceStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Hugging Face models.""" + + _model_name: str + _response: AsyncIterable[ChatCompletionStreamOutput] + _timestamp: datetime + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self._response: + self._usage += _map_usage(chunk) + + try: + choice = chunk.choices[0] + except IndexError: + continue + + # Handle the text part of the response + content = choice.delta.content + if content is not None: + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function and dtc.function.name, # type: ignore + args=dtc.function and dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self._model_name + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self._timestamp + + +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: + response_usage = response.usage + if response_usage is None: + return usage.Usage() + + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, + details=None, + ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 3a6baba6e..3fdeec8a2 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .heroku import HerokuProvider return HerokuProvider + elif provider == 'huggingface': + from .huggingface import HuggingFaceProvider + + return HuggingFaceProvider else: # pragma: no cover raise ValueError(f'Unknown provider: {provider}') diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py new file mode 100644 index 000000000..e18a60d16 --- /dev/null +++ 
b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py
@@ -0,0 +1,74 @@
+from __future__ import annotations as _annotations
+
+import os
+
+from httpx import AsyncClient
+
+from pydantic_ai.exceptions import UserError
+
+try:
+    from huggingface_hub import AsyncInferenceClient
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for Hugging Face."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: AsyncClient | None = None,
+        provider: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base URL for Hugging Face requests.
+            api_key: The API key to use for authentication. If not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider: Name of the provider to use for inference. Available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, then `provider` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise UserError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+                'to use the HuggingFace provider.'
+ ) + + if http_client is not None: + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead') + + if base_url is not None and provider is not None: + raise ValueError('Cannot provide both `base_url` and `provider`') + + if hf_client is None: + self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore + else: + self._client = hf_client diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index a04bd07c5..998dda16f 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,6 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.37.24"] +huggingface = ["huggingface-hub[inference]>=0.32.0"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/pyproject.toml b/pyproject.toml index b197af2ed..14d41d50c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals,a2a]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,a2a]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index 1bae76699..718f60f7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -291,6 +291,11 @@ def openrouter_api_key() -> str: return os.getenv('OPENROUTER_API_KEY', 'mock-api-key') +@pytest.fixture(scope='session') +def huggingface_api_key() -> str: + return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token') + + @pytest.fixture(scope='session') def heroku_inference_key() -> str: return os.getenv('HEROKU_INFERENCE_KEY', 'mock-api-key') @@ -324,6 +329,7 @@ def model( groq_api_key: str, co_api_key: str, gemini_api_key: str, + huggingface_api_key: str, bedrock_provider: BedrockProvider, ) -> Model: # pragma: lax no cover try: @@ -366,6 +372,14 @@ def model( from pydantic_ai.models.bedrock import BedrockConverseModel return BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + elif request.param == 'huggingface': + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + return HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key), + ) else: raise ValueError(f'Unknown model: {request.param}') except ImportError: diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml new file mode 100644 index 000000000..11bcb7596 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -0,0 +1,59 @@ +interactions: +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '560' + content-type: + - 
application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Paris + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749475551 + id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 2 + completion_tokens_details: null + prompt_tokens: 26 + prompt_tokens_details: null + total_tokens: 28 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml new file mode 100644 index 000000000..6996da033 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -0,0 +1,126 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '800' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hf-inference: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '680' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + 
logprobs: null + message: + audio: null + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749475549 + id: chatcmpl-6050852c70164258bb9bab4e93e2b69c + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 29 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 59 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py new file mode 100644 index 000000000..384328adf --- /dev/null +++ b/tests/models/test_huggingface.py @@ -0,0 +1,690 @@ +from __future__ import annotations as _annotations + +import json +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Literal, Union, cast +from unittest.mock import Mock + +import pytest +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ( + BinaryContent, + ImageUrl, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.result import Usage +from pydantic_ai.tools import RunContext + +from ..conftest import IsDatetime, IsNow, raise_if_exception, try_import +from .mock_async_stream import MockAsyncStream + +with try_import() as imports_successful: + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, + ) + from huggingface_hub.errors import HfHubHTTPError + + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + MockChatCompletion = Union[ChatCompletionOutput, Exception] + MockStreamEvent = Union[ChatCompletionStreamOutput, Exception] + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), + pytest.mark.anyio, +] + + +@dataclass +class MockHuggingFace: + completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None + stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] | None = None + index: int = 0 + + @cached_property + def chat(self) -> Any: + completions = type('Completions', (), {'create': self.chat_completions_create}) + return type('Chat', (), {'completions': completions}) + + @classmethod + def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(completions=completions)) + + @classmethod + def create_stream_mock( + cls, stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] + ) -> AsyncInferenceClient: + 
return cast(AsyncInferenceClient, cls(stream=stream)) + + async def chat_completions_create( + self, *_args: Any, stream: bool = False, **_kwargs: Any + ) -> ChatCompletionOutput | MockAsyncStream[MockStreamEvent]: + if stream or self.stream: + assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided' + if isinstance(self.stream[0], Sequence): + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream[self.index]))) + else: + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream))) + else: + assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided' + if isinstance(self.completions, Sequence): + raise_if_exception(self.completions[self.index]) + response = cast(ChatCompletionOutput, self.completions[self.index]) + else: + raise_if_exception(self.completions) + response = cast(ChatCompletionOutput, self.completions) + self.index += 1 + return response + + +def completion_message( + message: ChatCompletionInputMessage | ChatCompletionOutputMessage, *, usage: ChatCompletionOutputUsage | None = None +) -> ChatCompletionOutput: + choices = [ChatCompletionOutputComplete(finish_reason='stop', index=0, message=message)] # type:ignore + return ChatCompletionOutput.parse_obj_as_instance( # type: ignore + { + 'id': '123', + 'choices': choices, + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion', + 'usage': usage, + } + ) + + +async def test_simple_completion(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('hello') + assert result.output == 'world' + messages = result.all_messages() + request = messages[0] + response = messages[1] + assert request.parts[0].content == 'hello' # type: ignore + assert response == ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_request_simple_usage(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('Hello') + assert result.output == 'world' + assert result.usage() == snapshot(Usage(requests=1)) + + +async def test_request_structured_response(allow_model_requests: None): + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'final_result', + 'arguments': '{"response": [1, 2, 123]}', + } + ), + 'id': '123', + 'type': 'function', + } + ) + message = ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + c = completion_message(message) + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', 
hf_client=mock_client, api_key='x') + ) + agent = Agent(model, output_type=list[int]) + + result = await agent.run('Hello') + assert result.output == [1, 2, 123] + messages = result.all_messages() + assert messages[0].parts[0].content == 'Hello' # type: ignore + assert messages[1] == ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"response": [1, 2, 123]}', + tool_call_id='123', + ) + ], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_stream_completion(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world', finish_reason='stop')] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + + +async def test_request_tool_call(allow_model_requests: None): + tool_call_1 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "San Fransisco"}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + usage_1 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 1, + 'completion_tokens': 1, + 'total_tokens': 2, + } + ) + tool_call_2 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "London"}', + } + ), + 'id': '2', + 'type': 'function', + } + ) + usage_2 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 2, + 'completion_tokens': 1, + 'total_tokens': 3, + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_1], + } + ), + usage=usage_1, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_2], + } + ), + usage=usage_2, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 'final response', + 'role': 'assistant', + } + ), + ), + ] + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model, system_prompt='this is the system prompt') + + @agent.tool_plain + async def get_location(loc_name: str) -> str: + if loc_name == 'London': + return json.dumps({'lat': 51, 'lng': 0}) + else: + raise ModelRetry('Wrong location, please try again') + + result = await agent.run('Hello') + assert result.output == 'final response' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "San Fransisco"}', + tool_call_id='1', + ) + ], + usage=Usage(requests=1, request_tokens=1, response_tokens=1, 
total_tokens=2), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrong location, please try again', + tool_name='get_location', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "London"}', + tool_call_id='2', + ) + ], + usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='{"lat": 51, "lng": 0}', + tool_call_id='2', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] + + +def chunk( + delta: list[ChatCompletionStreamOutputDelta], finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return ChatCompletionStreamOutput.parse_obj_as_instance( # type: ignore + { + 'id': 'x', + 'choices': [ + ChatCompletionStreamOutputChoice(index=index, delta=delta, finish_reason=finish_reason) # type: ignore + for index, delta in enumerate(delta) + ], + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion.chunk', + 'usage': ChatCompletionStreamOutputUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), # type: ignore + } + ) + + +def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatCompletionStreamOutput: + return chunk([ChatCompletionStreamOutputDelta(content=text, role='assistant')], finish_reason=finish_reason) # type: ignore + + +async def test_stream_text(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world'), chunk([])] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_stream_text_finish_reason(allow_model_requests: None): + stream = [ + text_chunk('hello '), + text_chunk('world'), + text_chunk('.', finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) + assert result.is_complete + + +def struc_chunk( + tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return chunk( + [ + ChatCompletionStreamOutputDelta.parse_obj_as_instance( # type: ignore + { + 'role': 'assistant', + 'tool_calls': [ + ChatCompletionStreamOutputDeltaToolCall.parse_obj_as_instance( # type: ignore + { + 
'index': 0, + 'function': ChatCompletionStreamOutputFunction.parse_obj_as_instance( # type: ignore + { + 'name': tool_name, + 'arguments': tool_arguments, + } + ), + } + ) + ], + } + ), + ], + finish_reason=finish_reason, + ) + + +class MyTypedDict(TypedDict, total=False): + first: str + second: str + + +async def test_stream_structured(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant', tool_calls=[])]), # type: ignore + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk('final_result', None), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + chunk([]), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {}, + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) + + +async def test_stream_structured_finish_reason(allow_model_requests: None): + stream = [ + struc_chunk('final_result', None), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + struc_chunk(None, None, finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + +async def test_no_content(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + with pytest.raises(UnexpectedModelBehavior, match='Received empty model 
response'): + async with agent.run_stream(''): + pass + + +async def test_no_delta(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('hello '), + text_chunk('world'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_image_url_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + result = await agent.run( + [ + 'hello', + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), + ] + ) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'hello', + ImageUrl( + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg' + ), + ], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +async def test_image_as_binary_content_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + base64_content = ( + b'/9j/4AAQSkZJRgABAQEAYABgAAD/4QBYRXhpZgAATU0AKgAAAAgAA1IBAAEAAAABAAAAPgIBAAEAAAABAAAARgMBAAEAAAABAAAA' + b'WgAAAAAAAAAE' + ) + + result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='image/jpeg')]) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=['hello', BinaryContent(data=base64_content, media_type='image/jpeg')], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +def test_model_status_error(allow_model_requests: None) -> None: + error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + agent.run_sync('hello') + assert str(exc_info.value) == snapshot("status_code: 500, model_name: not_a_model, body: {'error': 'test error'}") + + +@pytest.mark.vcr() +async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + result = await agent.run('hello') 
+    assert result.output == snapshot(
+        'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.'
+    )
+
+
+@pytest.mark.vcr()
+async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str):
+    m = HuggingFaceModel(
+        'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key)
+    )
+
+    def simple_instructions(ctx: RunContext):
+        return 'You are a helpful assistant.'
+
+    agent = Agent(m, instructions=simple_instructions)
+
+    result = await agent.run('What is the capital of France?')
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[UserPromptPart(content='What is the capital of France?', timestamp=IsDatetime())],
+                instructions='You are a helpful assistant.',
+            ),
+            ModelResponse(
+                parts=[TextPart(content='Paris')],
+                usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28),
+                model_name='Qwen/Qwen2.5-72B-Instruct-fast',
+                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8',
+            ),
+        ]
+    )
diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py
index 52a3397a4..db6f22cd8 100644
--- a/tests/models/test_model_names.py
+++ b/tests/models/test_model_names.py
@@ -16,6 +16,7 @@
     from pydantic_ai.models.cohere import CohereModelName
     from pydantic_ai.models.gemini import GeminiModelName
     from pydantic_ai.models.groq import GroqModelName
+    from pydantic_ai.models.huggingface import HuggingFaceModelName
     from pydantic_ai.models.mistral import MistralModelName
     from pydantic_ai.models.openai import OpenAIModelName
@@ -54,6 +55,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
     ]
     bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)]
     deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner']
+    huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)]
     heroku_names = get_heroku_model_names()
     extra_names = ['test']
@@ -66,6 +68,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]:
         + openai_names
         + bedrock_names
         + deepseek_names
+        + huggingface_names
         + heroku_names
         + extra_names
     )
diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py
new file mode 100644
index 000000000..970c9d636
--- /dev/null
+++ b/tests/providers/test_huggingface.py
@@ -0,0 +1,61 @@
+from __future__ import annotations as _annotations
+
+import re
+
+import httpx
+import pytest
+
+from pydantic_ai.exceptions import UserError
+
+from ..conftest import TestEnv, try_import
+
+with try_import() as imports_successful:
+    from huggingface_hub import AsyncInferenceClient
+
+    from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+
+pytestmark = pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed')
+
+
+def test_huggingface_provider():
+    hf_client = AsyncInferenceClient(api_key='api-key')
+    provider = HuggingFaceProvider(api_key='api-key', hf_client=hf_client)
+    assert provider.name == 'huggingface'
+    assert isinstance(provider.client, AsyncInferenceClient)
+    assert provider.client.token == 'api-key'
+
+
+def test_huggingface_provider_need_api_key(env: TestEnv) -> None:
+    env.remove('HF_TOKEN')
+    with pytest.raises(
+        UserError,
+        match=re.escape(
+            'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+            'to use the HuggingFace provider.'
+        ),
+    ):
+        HuggingFaceProvider()
+
+
+def test_huggingface_provider_pass_http_client() -> None:
+    http_client = httpx.AsyncClient()
+    with pytest.raises(
+        ValueError,
+        match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'),
+    ):
+        HuggingFaceProvider(http_client=http_client, api_key='api-key')
+
+
+def test_huggingface_provider_pass_hf_client() -> None:
+    hf_client = AsyncInferenceClient(api_key='api-key')
+    provider = HuggingFaceProvider(hf_client=hf_client, api_key='api-key')
+    assert provider.client == hf_client
+
+
+def test_hf_provider_with_base_url() -> None:
+    # base_url is taken from the AsyncInferenceClient passed in via `hf_client`
+    provider = HuggingFaceProvider(
+        hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key'
+    )
+    assert provider.base_url == 'https://router.huggingface.co/nebius/v1'
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 024116249..8efc0da00 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]):
         'cohere',
         'deepseek',
         'heroku',
+        'huggingface',
     )
     models = {line.strip().split(' ')[0] for line in output[3:]}
     for provider in providers:
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 214dcd419..1a26ec003 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -149,6 +149,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
     env.set('CO_API_KEY', 'testing')
     env.set('MISTRAL_API_KEY', 'testing')
     env.set('ANTHROPIC_API_KEY', 'testing')
+    env.set('HF_TOKEN', 'hf_testing')
     env.set('AWS_ACCESS_KEY_ID', 'testing')
     env.set('AWS_SECRET_ACCESS_KEY', 'testing')
     env.set('AWS_DEFAULT_REGION', 'us-east-1')
diff --git a/uv.lock b/uv.lock
index 24f7bb52c..b0473976e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1312,6 +1312,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" },
 ]
 
+[[package]]
+name = "hf-xet"
+version = "1.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/75/dc/dc091aeeb671e71cbec30e84963f9c0202c17337b24b0a800e7d205543e8/hf_xet-1.1.3.tar.gz", hash = "sha256:a5f09b1dd24e6ff6bcedb4b0ddab2d81824098bb002cf8b4ffa780545fa348c3", size = 488127, upload-time = "2025-06-04T00:47:27.456Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/9b/1f/bc01a4c0894973adebbcd4aa338a06815c76333ebb3921d94dcbd40dae6a/hf_xet-1.1.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c3b508b5f583a75641aebf732853deb058953370ce8184f5dabc49f803b0819b", size = 2256929, upload-time = "2025-06-04T00:47:21.206Z" },
+    { url = "https://files.pythonhosted.org/packages/78/07/6ef50851b5c6b45b77a6e018fa299c69a2db3b8bbd0d5af594c0238b1ceb/hf_xet-1.1.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b788a61977fbe6b5186e66239e2a329a3f0b7e7ff50dad38984c0c74f44aeca1", size = 2153719, upload-time = "2025-06-04T00:47:19.302Z" },
+    { url = "https://files.pythonhosted.org/packages/52/48/e929e6e3db6e4758c2adf0f2ca2c59287f1b76229d8bdc1a4c9cfc05212e/hf_xet-1.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd2da210856444a34aad8ada2fc12f70dabed7cc20f37e90754d1d9b43bc0534", size = 4820519, upload-time = "2025-06-04T00:47:17.244Z" },
+    { url = "https://files.pythonhosted.org/packages/28/2e/03f89c5014a5aafaa9b150655f811798a317036646623bdaace25f485ae8/hf_xet-1.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8203f52827e3df65981984936654a5b390566336956f65765a8aa58c362bb841", size = 4964121, upload-time = "2025-06-04T00:47:15.17Z" },
+    { url = "https://files.pythonhosted.org/packages/47/8b/5cd399a92b47d98086f55fc72d69bc9ea5e5c6f27a9ed3e0cdd6be4e58a3/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:30c575a5306f8e6fda37edb866762140a435037365eba7a17ce7bd0bc0216a8b", size = 5283017, upload-time = "2025-06-04T00:47:23.239Z" },
+    { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" },
+    { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" },
+]
+
 [[package]]
 name = "httpcore"
 version = "1.0.7"
@@ -1351,20 +1366,26 @@ wheels = [
 
 [[package]]
 name = "huggingface-hub"
-version = "0.29.1"
+version = "0.32.4"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "filelock" },
     { name = "fsspec" },
+    { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
     { name = "packaging" },
     { name = "pyyaml" },
     { name = "requests" },
     { name = "tqdm" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776, upload-time = "2025-02-20T09:24:59.839Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494, upload-time = "2025-06-03T09:59:46.105Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049, upload-time = "2025-02-20T09:24:57.962Z" },
+    { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" },
+]
+
+[package.optional-dependencies]
+inference = [
+    { name = "aiohttp" },
 ]
 
 [[package]]
@@ -2896,7 +2917,7 @@ wheels = [
 name = "pydantic-ai"
 source = { editable = "." }
 dependencies = [
-    { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] },
+    { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] },
 ]
 
 [package.optional-dependencies]
@@ -2930,7 +2951,7 @@ lint = [
 requires-dist = [
     { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" },
     { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" },
-    { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" },
+    { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" },
 ]
 provides-extras = ["examples", "logfire"]
@@ -3029,6 +3050,9 @@ google = [
 groq = [
     { name = "groq" },
 ]
+huggingface = [
+    { name = "huggingface-hub", extra = ["inference"] },
+]
 logfire = [
     { name = "logfire" },
 ]
@@ -3084,6 +3108,7 @@ requires-dist = [
     { name = "griffe", specifier = ">=1.3.2" },
     { name = "groq", marker = "extra == 'groq'", specifier = ">=0.19.0" },
     { name = "httpx", specifier = ">=0.27" },
+    { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.32.0" },
     { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" },
     { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.4" },
     { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" },
@@ -3098,7 +3123,7 @@ requires-dist = [
     { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" },
     { name = "typing-inspection", specifier = ">=0.4.0" },
 ]
-provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"]
+provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"]
 
 [package.metadata.requires-dev]
 dev = [