Skip to content

Implement entrypoint typing #3387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions src/zenml/integrations/huggingface/steps/accelerate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,26 @@
#
"""Step function to run any ZenML step using Accelerate."""

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast
from typing import (
Any,
Callable,
Dict,
Optional,
TypeVar,
Union,
cast,
overload,
)

import cloudpickle as pickle
from accelerate.commands.launch import (
launch_command,
launch_command_parser,
)
from typing_extensions import ParamSpec

from zenml import get_pipeline_context
from zenml.logger import get_logger
Expand All @@ -32,12 +44,27 @@

logger = get_logger(__name__)
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
R = TypeVar("R")


@overload
def run_with_accelerate(
**accelerate_launch_kwargs: Any,
) -> Callable[[BaseStep[P, R]], BaseStep[P, R]]: ...


@overload
def run_with_accelerate(
step_function_top_level: BaseStep[P, R],
/,
) -> BaseStep[P, R]: ...


def run_with_accelerate(
step_function_top_level: Optional[BaseStep] = None,
step_function_top_level: Optional[BaseStep[P, R]] = None,
**accelerate_launch_kwargs: Any,
) -> Union[Callable[[BaseStep], BaseStep], BaseStep]:
) -> Union[Callable[[BaseStep[P, R]], BaseStep[P, R]], BaseStep[P, R]]:
"""Run a function with accelerate.

Accelerate package: https://huggingface.co/docs/accelerate/en/index
Expand Down Expand Up @@ -70,9 +97,10 @@ def training_pipeline(some_param: int, ...):
The accelerate-enabled version of the step.
"""

def _decorator(step_function: BaseStep) -> BaseStep:
def _decorator(step_function: BaseStep[P, R]) -> BaseStep[P, R]:
def _wrapper(
entrypoint: F, accelerate_launch_kwargs: Dict[str, Any]
entrypoint: F,
accelerate_launch_kwargs: Dict[str, Any],
) -> F:
@functools.wraps(entrypoint)
def inner(*args: Any, **kwargs: Any) -> Any:
Expand Down
6 changes: 4 additions & 2 deletions src/zenml/integrations/whylogs/steps/whylogs_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# permissions and limitations under the License.
"""Implementation of the whylogs profiler step."""

from __future__ import annotations

import datetime
from typing import Optional, cast
from typing import Any, Optional, cast

import pandas as pd
from whylogs.core import DatasetProfileView # type: ignore
Expand Down Expand Up @@ -58,7 +60,7 @@ def get_whylogs_profiler_step(
dataset_timestamp: Optional[datetime.datetime] = None,
dataset_id: Optional[str] = None,
enable_whylabs: bool = True,
) -> BaseStep:
) -> BaseStep[..., Any]:
"""Shortcut function to create a new instance of the WhylogsProfilerStep step.

The returned WhylogsProfilerStep can be used in a pipeline to generate a
Expand Down
6 changes: 4 additions & 2 deletions src/zenml/orchestrators/step_run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# permissions and limitations under the License.
"""Utilities for creating step runs."""

from typing import Dict, List, Optional, Set, Tuple
from __future__ import annotations

from typing import Any, Dict, List, Optional, Set, Tuple

from zenml.client import Client
from zenml.config.step_configurations import Step
Expand Down Expand Up @@ -185,7 +187,7 @@ def _get_docstring_and_source_code_from_step_instance(
"""
from zenml.steps.base_step import BaseStep

step_instance = BaseStep.load_from_source(step.spec.source)
step_instance = BaseStep[Any, Any].load_from_source(step.spec.source)

docstring = step_instance.docstring
if docstring and len(docstring) > TEXT_FIELD_MAX_LENGTH:
Expand Down
8 changes: 6 additions & 2 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Class to run steps."""

from __future__ import annotations

import copy
import inspect
from contextlib import nullcontext
Expand Down Expand Up @@ -302,15 +304,17 @@ def _evaluate_artifact_names_in_collections(
for d in collections:
d[name] = d.pop(k)

def _load_step(self) -> "BaseStep":
def _load_step(self) -> "BaseStep[..., Any]":
"""Load the step instance.

Returns:
The step instance.
"""
from zenml.steps import BaseStep

step_instance = BaseStep.load_from_source(self._step.spec.source)
step_instance = BaseStep[Any, Any].load_from_source(
self._step.spec.source
)
step_instance = copy.deepcopy(step_instance)
step_instance._configuration = self._step.config
return step_instance
Expand Down
6 changes: 4 additions & 2 deletions src/zenml/pipelines/pipeline_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# permissions and limitations under the License.
"""Definition of a ZenML pipeline."""

from __future__ import annotations

import copy
import hashlib
import inspect
Expand Down Expand Up @@ -1138,7 +1140,7 @@ def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str:

def add_step_invocation(
self,
step: "BaseStep",
step: "BaseStep[..., Any]",
input_artifacts: Dict[str, StepArtifact],
external_artifacts: Dict[
str, Union["ExternalArtifact", "ArtifactVersionResponse"]
Expand Down Expand Up @@ -1208,7 +1210,7 @@ def add_step_invocation(

def _compute_invocation_id(
self,
step: "BaseStep",
step: "BaseStep[..., Any]",
custom_id: Optional[str] = None,
allow_suffix: bool = True,
) -> str:
Expand Down
21 changes: 13 additions & 8 deletions src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# permissions and limitations under the License.
"""Base Step for ZenML."""

from __future__ import annotations

import copy
import hashlib
import inspect
Expand All @@ -22,6 +24,7 @@
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Mapping,
Optional,
Expand All @@ -33,6 +36,7 @@
)

from pydantic import BaseModel, ConfigDict, ValidationError
from typing_extensions import ParamSpec

from zenml.client_lazy_loader import ClientLazyLoader
from zenml.config.retry_config import StepRetryConfig
Expand Down Expand Up @@ -91,10 +95,11 @@

logger = get_logger(__name__)

T = TypeVar("T", bound="BaseStep")
P = ParamSpec("P")
R = TypeVar("R")


class BaseStep:
class BaseStep(Generic[P, R]):
"""Abstract base class for all ZenML steps."""

def __init__(
Expand Down Expand Up @@ -212,7 +217,7 @@ def __init__(
notebook_utils.try_to_save_notebook_cell_code(self.source_object)

@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
def entrypoint(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Abstract method for core step logic.

Args:
Expand All @@ -224,7 +229,7 @@ def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
"""

@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep[P, R]":
"""Loads a step from source.

Args:
Expand Down Expand Up @@ -581,7 +586,7 @@ def configuration(self) -> "PartialStepConfiguration":
return self._configuration

def configure(
self: T,
self: "BaseStep[P,R]",
enable_cache: Optional[bool] = None,
enable_artifact_metadata: Optional[bool] = None,
enable_artifact_visualization: Optional[bool] = None,
Expand All @@ -600,7 +605,7 @@ def configure(
merge: bool = True,
retry: Optional[StepRetryConfig] = None,
substitutions: Optional[Dict[str, str]] = None,
) -> T:
) -> "BaseStep[P,R]":
"""Configures the step.

Configuration merging example:
Expand Down Expand Up @@ -733,7 +738,7 @@ def with_options(
model: Optional["Model"] = None,
merge: bool = True,
substitutions: Optional[Dict[str, str]] = None,
) -> "BaseStep":
) -> "BaseStep[P, R]":
"""Copies the step and applies the given configurations.

Args:
Expand Down Expand Up @@ -789,7 +794,7 @@ def with_options(
)
return step_copy

def copy(self) -> "BaseStep":
def copy(self) -> "BaseStep[P, R]":
"""Copies the step.

Returns:
Expand Down
11 changes: 9 additions & 2 deletions src/zenml/steps/decorated_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@
# permissions and limitations under the License.
"""Internal BaseStep subclass used by the step decorator."""

from typing import Any
from __future__ import annotations

from typing import Any, TypeVar

from typing_extensions import ParamSpec

from zenml.config.source import Source
from zenml.steps import BaseStep

P = ParamSpec("P")
R = TypeVar("R")


class _DecoratedStep(BaseStep):
class _DecoratedStep(BaseStep[P, R]):
"""Internal BaseStep subclass used by the step decorator."""

@property
Expand Down
19 changes: 12 additions & 7 deletions src/zenml/steps/step_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# permissions and limitations under the License.
"""Step decorator function."""

from __future__ import annotations

from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -27,6 +29,8 @@
overload,
)

from typing_extensions import ParamSpec

from zenml.logger import get_logger

if TYPE_CHECKING:
Expand All @@ -46,14 +50,15 @@
Mapping[str, MaterializerClassOrSource],
Mapping[str, Sequence[MaterializerClassOrSource]],
]
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
R = TypeVar("R")


logger = get_logger(__name__)


@overload
def step(_func: "F") -> "BaseStep": ...
def step(_func: "Callable[P, R]") -> "BaseStep[P, R]": ...


@overload
Expand All @@ -74,11 +79,11 @@ def step(
model: Optional["Model"] = None,
retry: Optional["StepRetryConfig"] = None,
substitutions: Optional[Dict[str, str]] = None,
) -> Callable[["F"], "BaseStep"]: ...
) -> Callable[["Callable[P, R]"], "BaseStep[P,R]"]: ...


def step(
_func: Optional["F"] = None,
_func: Optional["Callable[P, R]"] = None,
*,
name: Optional[str] = None,
enable_cache: Optional[bool] = None,
Expand All @@ -95,7 +100,7 @@ def step(
model: Optional["Model"] = None,
retry: Optional["StepRetryConfig"] = None,
substitutions: Optional[Dict[str, str]] = None,
) -> Union["BaseStep", Callable[["F"], "BaseStep"]]:
) -> Union["BaseStep[P,R]", Callable[["Callable[P, R]"], "BaseStep[P,R]"]]:
"""Decorator to create a ZenML step.

Args:
Expand Down Expand Up @@ -132,10 +137,10 @@ def step(
The step instance.
"""

def inner_decorator(func: "F") -> "BaseStep":
def inner_decorator(func: "Callable[P, R]") -> "BaseStep[P,R]":
from zenml.steps.decorated_step import _DecoratedStep

class_: Type["BaseStep"] = type(
class_: Type["BaseStep[P,R]"] = type(
func.__name__,
(_DecoratedStep,),
{
Expand Down
4 changes: 3 additions & 1 deletion src/zenml/steps/step_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# permissions and limitations under the License.
"""Step invocation class definition."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Set, Union

from zenml.models import ArtifactVersionResponse
Expand All @@ -33,7 +35,7 @@ class StepInvocation:
def __init__(
self,
id: str,
step: "BaseStep",
step: "BaseStep[..., Any]",
input_artifacts: Dict[str, "StepArtifact"],
external_artifacts: Dict[
str, Union["ExternalArtifact", "ArtifactVersionResponse"]
Expand Down
Loading
Loading