diff --git a/src/zenml/integrations/huggingface/steps/accelerate_runner.py b/src/zenml/integrations/huggingface/steps/accelerate_runner.py index c1cabc4a0f9..9d8a14d3973 100644 --- a/src/zenml/integrations/huggingface/steps/accelerate_runner.py +++ b/src/zenml/integrations/huggingface/steps/accelerate_runner.py @@ -16,14 +16,26 @@ # """Step function to run any ZenML step using Accelerate.""" +from __future__ import annotations + import functools -from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Dict, + Optional, + TypeVar, + Union, + cast, + overload, +) import cloudpickle as pickle from accelerate.commands.launch import ( launch_command, launch_command_parser, ) +from typing_extensions import ParamSpec from zenml import get_pipeline_context from zenml.logger import get_logger @@ -32,12 +44,27 @@ logger = get_logger(__name__) F = TypeVar("F", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") + + +@overload +def run_with_accelerate( + **accelerate_launch_kwargs: Any, +) -> Callable[[BaseStep[P, R]], BaseStep[P, R]]: ... + + +@overload +def run_with_accelerate( + step_function_top_level: BaseStep[P, R], + /, +) -> BaseStep[P, R]: ... def run_with_accelerate( - step_function_top_level: Optional[BaseStep] = None, + step_function_top_level: Optional[BaseStep[P, R]] = None, **accelerate_launch_kwargs: Any, -) -> Union[Callable[[BaseStep], BaseStep], BaseStep]: +) -> Union[Callable[[BaseStep[P, R]], BaseStep[P, R]], BaseStep[P, R]]: """Run a function with accelerate. Accelerate package: https://huggingface.co/docs/accelerate/en/index @@ -70,9 +97,10 @@ def training_pipeline(some_param: int, ...): The accelerate-enabled version of the step. """ - def _decorator(step_function: BaseStep) -> BaseStep: + def _decorator(step_function: BaseStep[P, R]) -> BaseStep[P, R]: def _wrapper( - entrypoint: F, accelerate_launch_kwargs: Dict[str, Any] + entrypoint: F, + accelerate_launch_kwargs: Dict[str, Any], ) -> F: @functools.wraps(entrypoint) def inner(*args: Any, **kwargs: Any) -> Any: diff --git a/src/zenml/integrations/whylogs/steps/whylogs_profiler.py b/src/zenml/integrations/whylogs/steps/whylogs_profiler.py index cad37bd91e1..1d284381a4d 100644 --- a/src/zenml/integrations/whylogs/steps/whylogs_profiler.py +++ b/src/zenml/integrations/whylogs/steps/whylogs_profiler.py @@ -13,8 +13,10 @@ # permissions and limitations under the License. """Implementation of the whylogs profiler step.""" +from __future__ import annotations + import datetime -from typing import Optional, cast +from typing import Any, Optional, cast import pandas as pd from whylogs.core import DatasetProfileView # type: ignore @@ -58,7 +60,7 @@ def get_whylogs_profiler_step( dataset_timestamp: Optional[datetime.datetime] = None, dataset_id: Optional[str] = None, enable_whylabs: bool = True, -) -> BaseStep: +) -> BaseStep[..., Any]: """Shortcut function to create a new instance of the WhylogsProfilerStep step. The returned WhylogsProfilerStep can be used in a pipeline to generate a diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 2177f3ce5d5..861f2d814c9 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -13,7 +13,9 @@ # permissions and limitations under the License. """Utilities for creating step runs.""" -from typing import Dict, List, Optional, Set, Tuple +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Set, Tuple from zenml.client import Client from zenml.config.step_configurations import Step @@ -185,7 +187,7 @@ def _get_docstring_and_source_code_from_step_instance( """ from zenml.steps.base_step import BaseStep - step_instance = BaseStep.load_from_source(step.spec.source) + step_instance = BaseStep[Any, Any].load_from_source(step.spec.source) docstring = step_instance.docstring if docstring and len(docstring) > TEXT_FIELD_MAX_LENGTH: diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index f18f1c649a2..b51fe307fe9 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -14,6 +14,8 @@ """Class to run steps.""" +from __future__ import annotations + import copy import inspect from contextlib import nullcontext @@ -302,7 +304,7 @@ def _evaluate_artifact_names_in_collections( for d in collections: d[name] = d.pop(k) - def _load_step(self) -> "BaseStep": + def _load_step(self) -> "BaseStep[..., Any]": """Load the step instance. Returns: @@ -310,7 +312,9 @@ def _load_step(self) -> "BaseStep": """ from zenml.steps import BaseStep - step_instance = BaseStep.load_from_source(self._step.spec.source) + step_instance = BaseStep[Any, Any].load_from_source( + self._step.spec.source + ) step_instance = copy.deepcopy(step_instance) step_instance._configuration = self._step.config return step_instance diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index 70c99837f28..29571f99f50 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """Definition of a ZenML pipeline.""" +from __future__ import annotations + import copy import hashlib import inspect @@ -1138,7 +1140,7 @@ def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str: def add_step_invocation( self, - step: "BaseStep", + step: "BaseStep[..., Any]", input_artifacts: Dict[str, StepArtifact], external_artifacts: Dict[ str, Union["ExternalArtifact", "ArtifactVersionResponse"] @@ -1208,7 +1210,7 @@ def add_step_invocation( def _compute_invocation_id( self, - step: "BaseStep", + step: "BaseStep[..., Any]", custom_id: Optional[str] = None, allow_suffix: bool = True, ) -> str: diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 55eaf54739b..275db195c6a 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """Base Step for ZenML.""" +from __future__ import annotations + import copy import hashlib import inspect @@ -22,6 +24,7 @@ TYPE_CHECKING, Any, Dict, + Generic, List, Mapping, Optional, @@ -33,6 +36,7 @@ ) from pydantic import BaseModel, ConfigDict, ValidationError +from typing_extensions import ParamSpec from zenml.client_lazy_loader import ClientLazyLoader from zenml.config.retry_config import StepRetryConfig @@ -91,10 +95,11 @@ logger = get_logger(__name__) -T = TypeVar("T", bound="BaseStep") +P = ParamSpec("P") +R = TypeVar("R") -class BaseStep: +class BaseStep(Generic[P, R]): """Abstract base class for all ZenML steps.""" def __init__( @@ -212,7 +217,7 @@ def __init__( notebook_utils.try_to_save_notebook_cell_code(self.source_object) @abstractmethod - def entrypoint(self, *args: Any, **kwargs: Any) -> Any: + def entrypoint(self, *args: P.args, **kwargs: P.kwargs) -> R: """Abstract method for core step logic. Args: @@ -224,7 +229,7 @@ def entrypoint(self, *args: Any, **kwargs: Any) -> Any: """ @classmethod - def load_from_source(cls, source: Union[Source, str]) -> "BaseStep": + def load_from_source(cls, source: Union[Source, str]) -> "BaseStep[P, R]": """Loads a step from source. Args: @@ -581,7 +586,7 @@ def configuration(self) -> "PartialStepConfiguration": return self._configuration def configure( - self: T, + self: "BaseStep[P,R]", enable_cache: Optional[bool] = None, enable_artifact_metadata: Optional[bool] = None, enable_artifact_visualization: Optional[bool] = None, @@ -600,7 +605,7 @@ def configure( merge: bool = True, retry: Optional[StepRetryConfig] = None, substitutions: Optional[Dict[str, str]] = None, - ) -> T: + ) -> "BaseStep[P,R]": """Configures the step. Configuration merging example: @@ -733,7 +738,7 @@ def with_options( model: Optional["Model"] = None, merge: bool = True, substitutions: Optional[Dict[str, str]] = None, - ) -> "BaseStep": + ) -> "BaseStep[P, R]": """Copies the step and applies the given configurations. Args: @@ -789,7 +794,7 @@ def with_options( ) return step_copy - def copy(self) -> "BaseStep": + def copy(self) -> "BaseStep[P, R]": """Copies the step. Returns: diff --git a/src/zenml/steps/decorated_step.py b/src/zenml/steps/decorated_step.py index a0f82e7651a..a99a3905069 100644 --- a/src/zenml/steps/decorated_step.py +++ b/src/zenml/steps/decorated_step.py @@ -13,13 +13,20 @@ # permissions and limitations under the License. """Internal BaseStep subclass used by the step decorator.""" -from typing import Any +from __future__ import annotations + +from typing import Any, TypeVar + +from typing_extensions import ParamSpec from zenml.config.source import Source from zenml.steps import BaseStep +P = ParamSpec("P") +R = TypeVar("R") + -class _DecoratedStep(BaseStep): +class _DecoratedStep(BaseStep[P, R]): """Internal BaseStep subclass used by the step decorator.""" @property diff --git a/src/zenml/steps/step_decorator.py b/src/zenml/steps/step_decorator.py index bd546c9e916..60bd1a40441 100644 --- a/src/zenml/steps/step_decorator.py +++ b/src/zenml/steps/step_decorator.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """Step decorator function.""" +from __future__ import annotations + from typing import ( TYPE_CHECKING, Any, @@ -27,6 +29,8 @@ overload, ) +from typing_extensions import ParamSpec + from zenml.logger import get_logger if TYPE_CHECKING: @@ -46,14 +50,15 @@ Mapping[str, MaterializerClassOrSource], Mapping[str, Sequence[MaterializerClassOrSource]], ] - F = TypeVar("F", bound=Callable[..., Any]) + P = ParamSpec("P") + R = TypeVar("R") logger = get_logger(__name__) @overload -def step(_func: "F") -> "BaseStep": ... +def step(_func: "Callable[P, R]") -> "BaseStep[P, R]": ... @overload @@ -74,11 +79,11 @@ def step( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, -) -> Callable[["F"], "BaseStep"]: ... +) -> Callable[["Callable[P, R]"], "BaseStep[P,R]"]: ... def step( - _func: Optional["F"] = None, + _func: Optional["Callable[P, R]"] = None, *, name: Optional[str] = None, enable_cache: Optional[bool] = None, @@ -95,7 +100,7 @@ def step( model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, -) -> Union["BaseStep", Callable[["F"], "BaseStep"]]: +) -> Union["BaseStep[P,R]", Callable[["Callable[P, R]"], "BaseStep[P,R]"]]: """Decorator to create a ZenML step. Args: @@ -132,10 +137,10 @@ def step( The step instance. """ - def inner_decorator(func: "F") -> "BaseStep": + def inner_decorator(func: "Callable[P, R]") -> "BaseStep[P,R]": from zenml.steps.decorated_step import _DecoratedStep - class_: Type["BaseStep"] = type( + class_: Type["BaseStep[P,R]"] = type( func.__name__, (_DecoratedStep,), { diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 17341d40845..5fccd93987e 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """Step invocation class definition.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict, Set, Union from zenml.models import ArtifactVersionResponse @@ -33,7 +35,7 @@ class StepInvocation: def __init__( self, id: str, - step: "BaseStep", + step: "BaseStep[..., Any]", input_artifacts: Dict[str, "StepArtifact"], external_artifacts: Dict[ str, Union["ExternalArtifact", "ArtifactVersionResponse"] diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index 780ccd62589..e54eebee79c 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -14,6 +14,8 @@ """Utility functions and classes to run ZenML steps.""" +from __future__ import annotations + import ast import contextlib import inspect @@ -25,12 +27,13 @@ Dict, Optional, Tuple, + TypeVar, Union, ) from uuid import UUID from pydantic import BaseModel -from typing_extensions import Annotated +from typing_extensions import Annotated, ParamSpec from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client @@ -49,6 +52,9 @@ if TYPE_CHECKING: from zenml.steps import BaseStep + P = ParamSpec("P") + R = TypeVar("R") + logger = get_logger(__name__) @@ -499,7 +505,7 @@ def log_step_metadata( def run_as_single_step_pipeline( - __step: "BaseStep", *args: Any, **kwargs: Any + __step: "BaseStep[P,R]", *args: Any, **kwargs: Any ) -> Any: """Runs the step as a single step pipeline.