From 0877a4b3cd1df296bf48ea9fb168ead6bbb5d701 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 15:26:32 +0200 Subject: [PATCH 01/17] Improved Kubernetes orchestrator pod caching --- .../kubernetes_orchestrator_entrypoint.py | 118 +++++++++++------- src/zenml/models/v2/core/pipeline_run.py | 9 ++ src/zenml/orchestrators/dag_runner.py | 13 +- src/zenml/orchestrators/step_run_utils.py | 46 ++++--- src/zenml/pipelines/run_utils.py | 4 +- .../schemas/pipeline_run_schemas.py | 42 +++++-- src/zenml/zen_stores/sql_zen_store.py | 58 ++++----- 7 files changed, 187 insertions(+), 103 deletions(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 64cb912eb20..020ecef563c 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -15,8 +15,7 @@ import argparse import socket -from typing import Any, Dict, cast -from uuid import UUID +from typing import Callable, Dict, Optional, cast from kubernetes import client as k8s_client @@ -41,10 +40,15 @@ from zenml.logger import get_logger from zenml.orchestrators import publish_utils from zenml.orchestrators.dag_runner import NodeStatus, ThreadedDagRunner +from zenml.orchestrators.step_run_utils import ( + StepRunRequestFactory, + publish_cached_step_run, +) from zenml.orchestrators.utils import ( get_config_environment_vars, get_orchestrator_run_name, ) +from zenml.pipelines.run_utils import create_placeholder_run logger = get_logger(__name__) @@ -86,20 +90,6 @@ def main() -> None: step_command = StepEntrypointConfiguration.get_entrypoint_command() - if args.run_id and not pipeline_settings.prevent_orchestrator_pod_caching: - from zenml.orchestrators import cache_utils - - run_required = ( - cache_utils.create_cached_step_runs_and_prune_deployment( - deployment=deployment, - pipeline_run=client.get_pipeline_run(args.run_id), - stack=active_stack, - ) - ) - - if not run_required: - return - mount_local_stores = active_stack.orchestrator.config.is_local # Get a Kubernetes client from the active Kubernetes orchestrator, but @@ -126,6 +116,55 @@ def main() -> None: for owner_reference in owner_references: owner_reference.controller = False + if args.run_id: + pipeline_run = client.get_pipeline_run(args.run_id) + else: + pipeline_run = create_placeholder_run( + deployment=deployment, + orchestrator_run_id=orchestrator_pod_name, + ) + assert pipeline_run + + pre_step_run: Optional[Callable[[str], bool]] = None + + if pipeline_settings.prevent_orchestrator_pod_caching: + step_run_request_factory = StepRunRequestFactory( + deployment=deployment, + pipeline_run=pipeline_run, + stack=active_stack, + ) + step_runs = {} + + def pre_step_run(step_name: str) -> bool: + """Pre-step run. + + Args: + step_name: Name of the step. + + Returns: + Whether the step node needs to be run. + """ + step_run_request = step_run_request_factory.create_request( + step_name + ) + try: + step_run_request_factory.populate_request(step_run_request) + except Exception as e: + logger.error( + f"Failed to populate step run request for step {step_name}: {e}" + ) + return True + + if step_run_request.status == ExecutionStatus.CACHED: + step_run = publish_cached_step_run( + step_run_request, pipeline_run + ) + step_runs[step_name] = step_run + logger.info("Using cached version of step `%s`.", step_name) + return False + + return True + def run_step_on_kubernetes(step_name: str) -> None: """Run a pipeline step in a separate Kubernetes pod. @@ -249,29 +288,6 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: try: # Some steps may have failed because the pods could not be created. # We need to check for this and mark the step run as failed if so. - - # Fetch the pipeline run using any means possible. - list_args: Dict[str, Any] = {} - if args.run_id: - # For a run triggered outside of a schedule, we can use the - # placeholder run ID to find the pipeline run. - list_args = dict(id=UUID(args.run_id)) - else: - # For a run triggered by a schedule, we can only use the - # orchestrator run ID to find the pipeline run. - list_args = dict(orchestrator_run_id=orchestrator_pod_name) - - pipeline_runs = client.list_pipeline_runs( - hydrate=True, - project=deployment.project_id, - deployment_id=deployment.id, - **list_args, - ) - if not len(pipeline_runs): - # No pipeline run found, so we can't mark any step runs as failed. - return - - pipeline_run = pipeline_runs[0] pipeline_failed = False for step_name, node_state in node_states.items(): @@ -282,16 +298,21 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: # If steps failed for any reason, we need to mark the step run as # failed, if it exists and it wasn't already in a final state. + step_runs = Client().list_run_steps( + size=1, + pipeline_run_id=pipeline_run.id, + name=step_name, + ) - step_run = pipeline_run.steps.get(step_name) - - # Try to update the step run status, if it exists and is in - # a transient state. - if step_run and step_run.status in { - ExecutionStatus.INITIALIZING, - ExecutionStatus.RUNNING, - }: - publish_utils.publish_failed_step_run(step_run.id) + if step_runs: + step_run = step_runs[0] + # Try to update the step run status, if it exists and is in + # a transient state. + if step_run and step_run.status in { + ExecutionStatus.INITIALIZING, + ExecutionStatus.RUNNING, + }: + publish_utils.publish_failed_step_run(step_run.id) # If any steps failed and the pipeline run is still in a transient # state, we need to mark it as failed. @@ -319,6 +340,7 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: ThreadedDagRunner( dag=pipeline_dag, run_fn=run_step_on_kubernetes, + preparation_fn=pre_step_run, finalize_fn=finalize_run, parallel_node_startup_waiting_period=parallel_node_startup_waiting_period, max_parallelism=pipeline_settings.max_parallelism, diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 84c5ccd6d0e..988fc7b5345 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -125,6 +125,15 @@ class PipelineRunRequest(ProjectScopedRequest): title="Logs of the pipeline run.", ) + @property + def is_placeholder_request(self) -> bool: + """Whether the request is a placeholder request. + + Returns: + Whether the request is a placeholder request. + """ + return self.status == ExecutionStatus.INITIALIZING + model_config = ConfigDict(protected_namespaces=()) diff --git a/src/zenml/orchestrators/dag_runner.py b/src/zenml/orchestrators/dag_runner.py index e9234ea0c3d..a6a623eefb8 100644 --- a/src/zenml/orchestrators/dag_runner.py +++ b/src/zenml/orchestrators/dag_runner.py @@ -72,6 +72,7 @@ def __init__( self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any], + preparation_fn: Optional[Callable[[str], bool]] = None, finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None, parallel_node_startup_waiting_period: float = 0.0, max_parallelism: Optional[int] = None, @@ -83,6 +84,9 @@ def __init__( E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as `dag={2: [1], 3: [1], 4: [2, 3]}` run_fn: A function `run_fn(node)` that runs a single node + preparation_fn: A function that is called before the node is run. + If provided, the function return value determines whether the + node should be run or can be skipped. finalize_fn: A function `finalize_fn(node_states)` that is called when all nodes have completed. parallel_node_startup_waiting_period: Delay in seconds to wait in @@ -102,6 +106,7 @@ def __init__( self.dag = dag self.reversed_dag = reverse_dag(dag) self.run_fn = run_fn + self.preparation_fn = preparation_fn self.finalize_fn = finalize_fn self.nodes = dag.keys() self.node_states = { @@ -166,6 +171,12 @@ def _run_node(self, node: str) -> None: Args: node: The node. """ + if self.preparation_fn: + run_required = self.preparation_fn(node) + if not run_required: + self._finish_node(node) + return + self._prepare_node_run(node) try: @@ -203,8 +214,6 @@ def _finish_node(self, node: str, failed: bool = False) -> None: node: The node. failed: Whether the node failed. """ - # Update node status to completed. - assert self.node_states[node] == NodeStatus.RUNNING with self._lock: if failed: self.node_states[node] = NodeStatus.FAILED diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 55c83b07b17..8ed1291ca85 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -334,27 +334,15 @@ def create_cached_step_runs( # -> We don't need to do anything here continue - step_run = Client().zen_store.create_run_step(step_run_request) + step_run = publish_cached_step_run( + step_run_request, pipeline_run=pipeline_run + ) # Include the newly created step run in the step runs dictionary to # avoid fetching it again later when downstream steps need it for # input resolution. step_runs[invocation_id] = step_run - if ( - model_version := step_run.model_version - or pipeline_run.model_version - ): - link_output_artifacts_to_model_version( - artifacts=step_run.outputs, - model_version=model_version, - ) - - cascade_tags_for_output_artifacts( - artifacts=step_run.outputs, - tags=pipeline_run.config.tags, - ) - logger.info("Using cached version of step `%s`.", invocation_id) cached_invocations.add(invocation_id) @@ -427,3 +415,31 @@ def cascade_tags_for_output_artifacts( tags=[t.name for t in cascade_tags], artifact_version_id=output_artifact.id, ) + + +def publish_cached_step_run( + request: "StepRunRequest", pipeline_run: "PipelineRunResponse" +) -> "StepRunResponse": + """Create a cached step run and link to model version and tags. + + Args: + request: The request for the step run. + pipeline_run: The pipeline run of the step. + + Returns: + The createdstep run. + """ + step_run = Client().zen_store.create_run_step(request) + + if model_version := step_run.model_version or pipeline_run.model_version: + link_output_artifacts_to_model_version( + artifacts=step_run.outputs, + model_version=model_version, + ) + + cascade_tags_for_output_artifacts( + artifacts=step_run.outputs, + tags=pipeline_run.config.tags, + ) + + return step_run diff --git a/src/zenml/pipelines/run_utils.py b/src/zenml/pipelines/run_utils.py index 833ccae4310..af154b7bee9 100644 --- a/src/zenml/pipelines/run_utils.py +++ b/src/zenml/pipelines/run_utils.py @@ -50,6 +50,7 @@ def get_default_run_name(pipeline_name: str) -> str: def create_placeholder_run( deployment: "PipelineDeploymentResponse", + orchestrator_run_id: Optional[str] = None, logs: Optional["LogsRequest"] = None, ) -> Optional["PipelineRunResponse"]: """Create a placeholder run for the deployment. @@ -59,6 +60,7 @@ def create_placeholder_run( Args: deployment: The deployment for which to create the placeholder run. + orchestrator_run_id: The orchestrator run ID for the run. logs: The logs for the run. Returns: @@ -82,7 +84,7 @@ def create_placeholder_run( # the start_time is only set once the first step starts # running. start_time=start_time, - orchestrator_run_id=None, + orchestrator_run_id=orchestrator_run_id, project=deployment.project_id, deployment=deployment.id, pipeline=deployment.pipeline.id if deployment.pipeline else None, diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index c03fe593dc1..1d2cf98cc2b 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -51,7 +51,10 @@ from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from zenml.zen_stores.schemas.project_schemas import ProjectSchema from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema -from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.schema_utils import ( + build_foreign_key_field, + build_index, +) from zenml.zen_stores.schemas.stack_schemas import StackSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema from zenml.zen_stores.schemas.user_schemas import UserSchema @@ -89,6 +92,10 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): "project_id", name="unique_run_name_in_project", ), + build_index( + table_name="pipeline_run", + column_names=["deployment_id", "status"], + ), ) # Fields @@ -533,8 +540,8 @@ def update_placeholder( Raises: RuntimeError: If the DB entry does not represent a placeholder run. - ValueError: If the run request does not match the deployment or - pipeline ID of the placeholder run. + ValueError: If the run request is not a valid request to replace the + placeholder run. Returns: The updated `PipelineRunSchema`. @@ -545,13 +552,33 @@ def update_placeholder( "placeholder run." ) + if request.status == ExecutionStatus.INITIALIZING: + raise ValueError( + "Cannot replace a placeholder run with another placeholder run." + ) + if ( self.deployment_id != request.deployment or self.pipeline_id != request.pipeline + or self.project_id != request.project + ): + raise ValueError( + "Deployment, project or pipeline ID of placeholder run " + "do not match the IDs of the run request." + ) + + if not request.orchestrator_run_id: + raise ValueError( + "Orchestrator run ID is required to replace a placeholder run." + ) + + if ( + self.orchestrator_run_id + and self.orchestrator_run_id != request.orchestrator_run_id ): raise ValueError( - "Deployment or orchestrator run ID of placeholder run do not " - "match the IDs of the run request." + "Orchestrator run ID of placeholder run does not match the " + "ID of the run request." ) orchestrator_environment = json.dumps(request.orchestrator_environment) @@ -570,7 +597,4 @@ def is_placeholder_run(self) -> bool: Returns: Whether the pipeline run is a placeholder run. """ - return ( - self.orchestrator_run_id is None - and self.status == ExecutionStatus.INITIALIZING - ) + return self.status == ExecutionStatus.INITIALIZING.value diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1ed43abf062..eda1bb4df5d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5577,18 +5577,16 @@ def _replace_placeholder_run( # transaction to do so finishes. After the first transaction # finishes, the subsequent queries will not be able to find a # placeholder run anymore, as we already updated the - # orchestrator_run_id. + # status. # Note: This only locks a single row if the where clause of - # the query is indexed (we have a unique index due to the - # unique constraint on those columns). Otherwise, this will lock - # multiple rows or even the complete table which we want to - # avoid. + # the query is indexed. If you're modifying this, make sure to also + # update the index. Otherwise, this will lock multiple rows or even + # the complete table which we want to avoid. .with_for_update() .where(PipelineRunSchema.deployment_id == pipeline_run.deployment) .where( - PipelineRunSchema.orchestrator_run_id.is_(None) # type: ignore[union-attr] + PipelineRunSchema.status == ExecutionStatus.INITIALIZING.value ) - .where(PipelineRunSchema.project_id == pipeline_run.project) ).first() if not run_schema: @@ -5688,27 +5686,31 @@ def get_or_create_run( except KeyError: pass - try: - return ( - self._replace_placeholder_run( - pipeline_run=pipeline_run, - pre_replacement_hook=pre_creation_hook, - session=session, - ), - True, - ) - except KeyError: - # We were not able to find/replace a placeholder run. This could - # be due to one of the following three reasons: - # (1) There never was a placeholder run for the deployment. This - # is the case if the user ran the pipeline on a schedule. - # (2) There was a placeholder run, but a previous pipeline run - # already used it. This is the case if users rerun a - # pipeline run e.g. from the orchestrator UI, as they will - # use the same deployment_id with a new orchestrator_run_id. - # (3) A step of the same pipeline run already replaced the - # placeholder run. - pass + if not pipeline_run.is_placeholder_request: + # Only run this is the request is not a placeholder run itself, + # as we don't want to replace a placeholder run with another + # placeholder run. + try: + return ( + self._replace_placeholder_run( + pipeline_run=pipeline_run, + pre_replacement_hook=pre_creation_hook, + session=session, + ), + True, + ) + except KeyError: + # We were not able to find/replace a placeholder run. This could + # be due to one of the following three reasons: + # (1) There never was a placeholder run for the deployment. This + # is the case if the user ran the pipeline on a schedule. + # (2) There was a placeholder run, but a previous pipeline run + # already used it. This is the case if users rerun a + # pipeline run e.g. from the orchestrator UI, as they will + # use the same deployment_id with a new orchestrator_run_id. + # (3) A step of the same pipeline run already replaced the + # placeholder run. + pass try: # We now try to create a new run. The following will happen in From 30e212f2c544ba998fcd23817e41d7f31d611e7a Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 15:29:34 +0200 Subject: [PATCH 02/17] Enable for scheduled deployments --- .../kubernetes_orchestrator_entrypoint.py | 1 - src/zenml/pipelines/pipeline_definition.py | 8 ++++++-- src/zenml/pipelines/run_utils.py | 10 ++-------- src/zenml/zen_server/template_execution/utils.py | 1 - 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 020ecef563c..8b944ac1902 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -123,7 +123,6 @@ def main() -> None: deployment=deployment, orchestrator_run_id=orchestrator_pod_name, ) - assert pipeline_run pre_step_run: Optional[Callable[[str], bool]] = None diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index fcf452b8813..97d6151f4d5 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -863,8 +863,12 @@ def _run( deployment = self._create_deployment(**self._run_args) self.log_pipeline_deployment_metadata(deployment) - run = create_placeholder_run( - deployment=deployment, logs=logs_model + run = ( + create_placeholder_run( + deployment=deployment, logs=logs_model + ) + if not deployment.schedule + else None ) analytics_handler.metadata = ( diff --git a/src/zenml/pipelines/run_utils.py b/src/zenml/pipelines/run_utils.py index af154b7bee9..00c89fb82ab 100644 --- a/src/zenml/pipelines/run_utils.py +++ b/src/zenml/pipelines/run_utils.py @@ -52,23 +52,17 @@ def create_placeholder_run( deployment: "PipelineDeploymentResponse", orchestrator_run_id: Optional[str] = None, logs: Optional["LogsRequest"] = None, -) -> Optional["PipelineRunResponse"]: +) -> "PipelineRunResponse": """Create a placeholder run for the deployment. - If the deployment contains a schedule, no placeholder run will be - created. - Args: deployment: The deployment for which to create the placeholder run. orchestrator_run_id: The orchestrator run ID for the run. logs: The logs for the run. Returns: - The placeholder run or `None` if no run was created. + The placeholder run. """ - if deployment.schedule: - return None - start_time = utc_now() run_request = PipelineRunRequest( name=string_utils.format_name_template( diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index b5fbf3484a1..00f53f6e42d 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -193,7 +193,6 @@ def run_template( zenml_version = build.zenml_version placeholder_run = create_placeholder_run(deployment=new_deployment) - assert placeholder_run report_usage( feature=RUN_TEMPLATE_TRIGGERS_FEATURE_NAME, From 175e38ef9e61b024e970ccdc190f03a86b5b43eb Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 15:42:39 +0200 Subject: [PATCH 03/17] Improve placeholder run detection --- src/zenml/zen_stores/schemas/pipeline_run_schemas.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 1d2cf98cc2b..f410c8d26be 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -94,7 +94,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ), build_index( table_name="pipeline_run", - column_names=["deployment_id", "status"], + column_names=["deployment_id", "orchestrator_run_id", "status"], ), ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index eda1bb4df5d..d02c3471484 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5584,6 +5584,13 @@ def _replace_placeholder_run( # the complete table which we want to avoid. .with_for_update() .where(PipelineRunSchema.deployment_id == pipeline_run.deployment) + .where( + or_( + PipelineRunSchema.orchestrator_run_id + == pipeline_run.orchestrator_run_id, + col(PipelineRunSchema.orchestrator_run_id).is_(None), + ) + ) .where( PipelineRunSchema.status == ExecutionStatus.INITIALIZING.value ) From 598b43683f44ac30653e94668c24697891573cf8 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 16:32:32 +0200 Subject: [PATCH 04/17] Store orchestrator run ID, improve labeling --- .../orchestrators/kubernetes_orchestrator.py | 40 ++++++++++++++++--- .../kubernetes_orchestrator_entrypoint.py | 30 +++++++------- ...s_orchestrator_entrypoint_configuration.py | 12 ------ .../orchestrators/manifest_utils.py | 26 +++--------- .../kubernetes_step_operator.py | 6 ++- src/zenml/zen_stores/sql_zen_store.py | 3 ++ 6 files changed, 63 insertions(+), 54 deletions(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 4c0d7fc7b40..198fe06708a 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -47,6 +47,9 @@ from kubernetes import config as k8s_config from zenml.config.base_settings import BaseSettings +from zenml.constants import ( + METADATA_ORCHESTRATOR_RUN_ID, +) from zenml.enums import StackComponentType from zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor import ( KubernetesOrchestratorConfig, @@ -62,6 +65,7 @@ ) from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings from zenml.logger import get_logger +from zenml.metadata.metadata_types import MetadataType from zenml.orchestrators import ContainerizedOrchestrator from zenml.orchestrators.utils import get_orchestrator_run_name from zenml.stack import StackValidator @@ -460,9 +464,7 @@ def prepare_or_run_pipeline( # This will internally also build the command/args for all step pods. command = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command() args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments( - run_name=orchestrator_run_name, deployment_id=deployment.id, - kubernetes_namespace=self.config.kubernetes_namespace, run_id=placeholder_run.id if placeholder_run else None, ) @@ -501,6 +503,15 @@ def prepare_or_run_pipeline( } ) + orchestrator_pod_labels = { + "pipeline": kube_utils.sanitize_label(pipeline_name), + } + + if placeholder_run: + orchestrator_pod_labels["run_id"] = kube_utils.sanitize_label( + str(placeholder_run.id) + ) + # Schedule as CRON job if CRON schedule is given. if deployment.schedule: if not deployment.schedule.cron_expression: @@ -512,9 +523,7 @@ def prepare_or_run_pipeline( cron_expression = deployment.schedule.cron_expression cron_job_manifest = build_cron_job_manifest( cron_expression=cron_expression, - run_name=orchestrator_run_name, pod_name=pod_name, - pipeline_name=pipeline_name, image_name=image, command=command, args=args, @@ -526,6 +535,7 @@ def prepare_or_run_pipeline( successful_jobs_history_limit=settings.successful_jobs_history_limit, failed_jobs_history_limit=settings.failed_jobs_history_limit, ttl_seconds_after_finished=settings.ttl_seconds_after_finished, + labels=orchestrator_pod_labels, ) self._k8s_batch_api.create_namespaced_cron_job( @@ -540,9 +550,7 @@ def prepare_or_run_pipeline( else: # Create and run the orchestrator pod. pod_manifest = build_pod_manifest( - run_name=orchestrator_run_name, pod_name=pod_name, - pipeline_name=pipeline_name, image_name=image, command=command, args=args, @@ -550,6 +558,7 @@ def prepare_or_run_pipeline( pod_settings=orchestrator_pod_settings, service_account_name=service_account_name, env=environment, + labels=orchestrator_pod_labels, mount_local_stores=self.config.is_local, ) @@ -565,6 +574,10 @@ def prepare_or_run_pipeline( startup_timeout=settings.pod_startup_timeout, ) + yield { + METADATA_ORCHESTRATOR_RUN_ID: pod_name, + } + # Wait for the orchestrator pod to finish and stream logs. if settings.synchronous: logger.info( @@ -629,3 +642,18 @@ def get_orchestrator_run_id(self) -> str: "Unable to read run id from environment variable " f"{ENV_ZENML_KUBERNETES_RUN_ID}." ) + + def get_pipeline_run_metadata( + self, run_id: UUID + ) -> Dict[str, "MetadataType"]: + """Get general component-specific metadata for a pipeline run. + + Args: + run_id: The ID of the pipeline run. + + Returns: + A dictionary of metadata. + """ + return { + METADATA_ORCHESTRATOR_RUN_ID: self.get_orchestrator_run_id(), + } diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 8b944ac1902..9bdf25eb99d 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -60,9 +60,7 @@ def parse_args() -> argparse.Namespace: Parsed args. """ parser = argparse.ArgumentParser() - parser.add_argument("--run_name", type=str, required=True) parser.add_argument("--deployment_id", type=str, required=True) - parser.add_argument("--kubernetes_namespace", type=str, required=True) parser.add_argument("--run_id", type=str, required=False) return parser.parse_args() @@ -72,7 +70,6 @@ def main() -> None: # Log to the container's stdout so it can be streamed by the client. logger.info("Kubernetes orchestrator pod started.") - # Parse / extract args. args = parse_args() orchestrator_pod_name = socket.gethostname() @@ -81,6 +78,7 @@ def main() -> None: active_stack = client.active_stack orchestrator = active_stack.orchestrator assert isinstance(orchestrator, KubernetesOrchestrator) + namespace = orchestrator.config.kubernetes_namespace deployment = client.get_deployment(args.deployment_id) pipeline_settings = cast( @@ -105,7 +103,7 @@ def main() -> None: owner_references = kube_utils.get_pod_owner_references( core_api=core_api, pod_name=orchestrator_pod_name, - namespace=args.kubernetes_namespace, + namespace=namespace, ) except Exception as e: logger.warning(f"Failed to get pod owner references: {str(e)}") @@ -126,7 +124,7 @@ def main() -> None: pre_step_run: Optional[Callable[[str], bool]] = None - if pipeline_settings.prevent_orchestrator_pod_caching: + if not pipeline_settings.prevent_orchestrator_pod_caching: step_run_request_factory = StepRunRequestFactory( deployment=deployment, pipeline_run=pipeline_run, @@ -164,6 +162,13 @@ def pre_step_run(step_name: str) -> bool: return True + step_pod_labels = { + "run_id": kube_utils.sanitize_label(str(pipeline_run.id)), + "pipeline": kube_utils.sanitize_label( + deployment.pipeline_configuration.name + ), + } + def run_step_on_kubernetes(step_name: str) -> None: """Run a pipeline step in a separate Kubernetes pod. @@ -184,7 +189,7 @@ def run_step_on_kubernetes(step_name: str) -> None: ): max_length = ( kube_utils.calculate_max_pod_name_length_for_namespace( - namespace=args.kubernetes_namespace + namespace=namespace ) ) pod_name_prefix = get_orchestrator_run_name( @@ -194,9 +199,7 @@ def run_step_on_kubernetes(step_name: str) -> None: else: pod_name = f"{orchestrator_pod_name}-{step_name}" - pod_name = kube_utils.sanitize_pod_name( - pod_name, namespace=args.kubernetes_namespace - ) + pod_name = kube_utils.sanitize_pod_name(pod_name, namespace=namespace) image = KubernetesOrchestrator.get_image( deployment=deployment, step_name=step_name @@ -233,8 +236,6 @@ def run_step_on_kubernetes(step_name: str) -> None: # Define Kubernetes pod manifest. pod_manifest = build_pod_manifest( pod_name=pod_name, - run_name=args.run_name, - pipeline_name=deployment.pipeline_configuration.name, image_name=image, command=step_command, args=step_args, @@ -245,6 +246,7 @@ def run_step_on_kubernetes(step_name: str) -> None: or settings.service_account_name, mount_local_stores=mount_local_stores, owner_references=owner_references, + labels=step_pod_labels, ) kube_utils.create_and_wait_for_pod_to_start( @@ -252,7 +254,7 @@ def run_step_on_kubernetes(step_name: str) -> None: pod_display_name=f"pod for step `{step_name}`", pod_name=pod_name, pod_manifest=pod_manifest, - namespace=args.kubernetes_namespace, + namespace=namespace, startup_max_retries=settings.pod_failure_max_retries, startup_failure_delay=settings.pod_failure_retry_delay, startup_failure_backoff=settings.pod_failure_backoff, @@ -267,7 +269,7 @@ def run_step_on_kubernetes(step_name: str) -> None: incluster=True ), pod_name=pod_name, - namespace=args.kubernetes_namespace, + namespace=namespace, exit_condition_lambda=kube_utils.pod_is_done, stream_logs=True, ) @@ -354,7 +356,7 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: try: kube_utils.delete_secret( core_api=core_api, - namespace=args.kubernetes_namespace, + namespace=namespace, secret_name=secret_name, ) except k8s_client.rest.ApiException as e: diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py index bdaec167318..7b7732fb311 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py @@ -18,9 +18,7 @@ if TYPE_CHECKING: from uuid import UUID -RUN_NAME_OPTION = "run_name" DEPLOYMENT_ID_OPTION = "deployment_id" -NAMESPACE_OPTION = "kubernetes_namespace" RUN_ID_OPTION = "run_id" @@ -35,9 +33,7 @@ def get_entrypoint_options(cls) -> Set[str]: Entrypoint options. """ options = { - RUN_NAME_OPTION, DEPLOYMENT_ID_OPTION, - NAMESPACE_OPTION, } return options @@ -58,29 +54,21 @@ def get_entrypoint_command(cls) -> List[str]: @classmethod def get_entrypoint_arguments( cls, - run_name: str, deployment_id: "UUID", - kubernetes_namespace: str, run_id: Optional["UUID"] = None, ) -> List[str]: """Gets all arguments that the entrypoint command should be called with. Args: - run_name: Name of the ZenML run. deployment_id: ID of the deployment. - kubernetes_namespace: Name of the Kubernetes namespace. run_id: Optional ID of the pipeline run. Not set for scheduled runs. Returns: List of entrypoint arguments. """ args = [ - f"--{RUN_NAME_OPTION}", - run_name, f"--{DEPLOYMENT_ID_OPTION}", str(deployment_id), - f"--{NAMESPACE_OPTION}", - kubernetes_namespace, ] if run_id: diff --git a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py index d2d25125d3a..b33be6a17d6 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py +++ b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py @@ -26,7 +26,6 @@ from zenml.integrations.airflow.orchestrators.dag_generator import ( ENV_ZENML_LOCAL_STORES_PATH, ) -from zenml.integrations.kubernetes.orchestrators import kube_utils from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings from zenml.logger import get_logger @@ -97,8 +96,6 @@ def add_local_stores_mount( def build_pod_manifest( pod_name: str, - run_name: str, - pipeline_name: str, image_name: str, command: List[str], args: List[str], @@ -106,6 +103,7 @@ def build_pod_manifest( pod_settings: Optional[KubernetesPodSettings] = None, service_account_name: Optional[str] = None, env: Optional[Dict[str, str]] = None, + labels: Optional[Dict[str, str]] = None, mount_local_stores: bool = False, owner_references: Optional[List[k8s_client.V1OwnerReference]] = None, ) -> k8s_client.V1Pod: @@ -113,8 +111,6 @@ def build_pod_manifest( Args: pod_name: Name of the pod. - run_name: Name of the ZenML run. - pipeline_name: Name of the ZenML pipeline. image_name: Name of the Docker image. command: Command to execute the entrypoint in the pod. args: Arguments provided to the entrypoint command. @@ -124,6 +120,7 @@ def build_pod_manifest( Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. env: Environment variables to set. + labels: Labels to add to the pod. mount_local_stores: Whether to mount the local stores path inside the pod. owner_references: List of owner references for the pod. @@ -162,7 +159,7 @@ def build_pod_manifest( if service_account_name is not None: pod_spec.service_account_name = service_account_name - labels = {} + labels = labels or {} if pod_settings: add_pod_settings(pod_spec, pod_settings) @@ -171,14 +168,6 @@ def build_pod_manifest( if pod_settings.labels: labels.update(pod_settings.labels) - # Add run_name and pipeline_name to the labels - labels.update( - { - "run": kube_utils.sanitize_label(run_name), - "pipeline": kube_utils.sanitize_label(pipeline_name), - } - ) - pod_metadata = k8s_client.V1ObjectMeta( name=pod_name, labels=labels, @@ -272,8 +261,6 @@ def add_pod_settings( def build_cron_job_manifest( cron_expression: str, pod_name: str, - run_name: str, - pipeline_name: str, image_name: str, command: List[str], args: List[str], @@ -281,6 +268,7 @@ def build_cron_job_manifest( pod_settings: Optional[KubernetesPodSettings] = None, service_account_name: Optional[str] = None, env: Optional[Dict[str, str]] = None, + labels: Optional[Dict[str, str]] = None, mount_local_stores: bool = False, successful_jobs_history_limit: Optional[int] = None, failed_jobs_history_limit: Optional[int] = None, @@ -291,8 +279,6 @@ def build_cron_job_manifest( Args: cron_expression: CRON job schedule expression, e.g. "* * * * *". pod_name: Name of the pod. - run_name: Name of the ZenML run. - pipeline_name: Name of the ZenML pipeline. image_name: Name of the Docker image. command: Command to execute the entrypoint in the pod. args: Arguments provided to the entrypoint command. @@ -302,6 +288,7 @@ def build_cron_job_manifest( Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster. env: Environment variables to set. + labels: Labels to add to the pod. mount_local_stores: Whether to mount the local stores path inside the pod. successful_jobs_history_limit: The number of successful jobs to retain. @@ -314,8 +301,6 @@ def build_cron_job_manifest( """ pod_manifest = build_pod_manifest( pod_name=pod_name, - run_name=run_name, - pipeline_name=pipeline_name, image_name=image_name, command=command, args=args, @@ -323,6 +308,7 @@ def build_cron_job_manifest( pod_settings=pod_settings, service_account_name=service_account_name, env=env, + labels=labels, mount_local_stores=mount_local_stores, ) diff --git a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py index 7cc8c3ff6b2..76c978aecf3 100644 --- a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py @@ -205,9 +205,7 @@ def launch( # Create and run the orchestrator pod. pod_manifest = build_pod_manifest( - run_name=info.run_name, pod_name=pod_name, - pipeline_name=info.pipeline.name, image_name=image_name, command=command, args=args, @@ -216,6 +214,10 @@ def launch( pod_settings=settings.pod_settings, env=environment, mount_local_stores=False, + labels={ + "run_id": kube_utils.sanitize_label(str(info.run_id)), + "pipeline": kube_utils.sanitize_label(info.pipeline.name), + }, ) kube_utils.create_and_wait_for_pod_to_start( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index d02c3471484..4c69383d4e3 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5641,6 +5641,9 @@ def _get_run_by_orchestrator_run_id( .where( PipelineRunSchema.orchestrator_run_id == orchestrator_run_id ) + .where( + PipelineRunSchema.status != ExecutionStatus.INITIALIZING.value + ) ).first() if not run_schema: From b1d69b21bc81945c5ab95916469a196b77ac550d Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 16:53:27 +0200 Subject: [PATCH 05/17] Index DB migration --- .../f9343f6633bd_add_pipeline_run_index.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py diff --git a/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py b/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py new file mode 100644 index 00000000000..5685a8c94d0 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py @@ -0,0 +1,39 @@ +"""Add pipeline run index [f9343f6633bd]. + +Revision ID: f9343f6633bd +Revises: 0.83.0 +Create Date: 2025-06-02 16:53:01.866526 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f9343f6633bd" +down_revision = "0.83.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.create_index( + "ix_pipeline_run_deployment_id_orchestrator_run_id_status", + ["deployment_id", "orchestrator_run_id", "status"], + unique=False, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("pipeline_run", schema=None) as batch_op: + batch_op.drop_index( + "ix_pipeline_run_deployment_id_orchestrator_run_id_status" + ) + + # ### end Alembic commands ### From a771454ee0a7d9844d4e2723d89bc09b254aadec Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Jun 2025 17:14:26 +0200 Subject: [PATCH 06/17] Docstring --- .../kubernetes/orchestrators/kubernetes_orchestrator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 198fe06708a..3ed17ff9fb6 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -410,6 +410,9 @@ def prepare_or_run_pipeline( Raises: RuntimeError: If the Kubernetes orchestrator is not configured. + + Yields: + Metadata dictionary. """ for step_name, step in deployment.step_configurations.items(): if self.requires_resources_in_orchestration_environment(step): From b780b48ea90c40ee82e7e993e2456b6a14404381 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 3 Jun 2025 16:39:54 +0200 Subject: [PATCH 07/17] Remove unnecessary index --- .../f9343f6633bd_add_pipeline_run_index.py | 39 ------------------- .../schemas/pipeline_run_schemas.py | 5 --- src/zenml/zen_stores/sql_zen_store.py | 8 ++-- 3 files changed, 4 insertions(+), 48 deletions(-) delete mode 100644 src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py diff --git a/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py b/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py deleted file mode 100644 index 5685a8c94d0..00000000000 --- a/src/zenml/zen_stores/migrations/versions/f9343f6633bd_add_pipeline_run_index.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Add pipeline run index [f9343f6633bd]. - -Revision ID: f9343f6633bd -Revises: 0.83.0 -Create Date: 2025-06-02 16:53:01.866526 - -""" - -from alembic import op - -# revision identifiers, used by Alembic. -revision = "f9343f6633bd" -down_revision = "0.83.0" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - """Upgrade database schema and/or data, creating a new revision.""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("pipeline_run", schema=None) as batch_op: - batch_op.create_index( - "ix_pipeline_run_deployment_id_orchestrator_run_id_status", - ["deployment_id", "orchestrator_run_id", "status"], - unique=False, - ) - - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade database schema and/or data back to the previous revision.""" - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("pipeline_run", schema=None) as batch_op: - batch_op.drop_index( - "ix_pipeline_run_deployment_id_orchestrator_run_id_status" - ) - - # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index f410c8d26be..ed7b076a8b5 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -53,7 +53,6 @@ from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.schema_utils import ( build_foreign_key_field, - build_index, ) from zenml.zen_stores.schemas.stack_schemas import StackSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema @@ -92,10 +91,6 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): "project_id", name="unique_run_name_in_project", ), - build_index( - table_name="pipeline_run", - column_names=["deployment_id", "orchestrator_run_id", "status"], - ), ) # Fields diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 4c69383d4e3..91bb50589b9 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5578,10 +5578,10 @@ def _replace_placeholder_run( # finishes, the subsequent queries will not be able to find a # placeholder run anymore, as we already updated the # status. - # Note: This only locks a single row if the where clause of - # the query is indexed. If you're modifying this, make sure to also - # update the index. Otherwise, this will lock multiple rows or even - # the complete table which we want to avoid. + # Note: Due to our unique index on deployment_id and + # orchestrator_run_id, this only locks a single row. If you're + # modifying this WHERE clause, make sure to test/adjust so this + # does not lock multiple rows or even the complete table. .with_for_update() .where(PipelineRunSchema.deployment_id == pipeline_run.deployment) .where( From 54be72d91cbba61450f5f648d64a245c4f4a8cdc Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 4 Jun 2025 10:15:31 +0200 Subject: [PATCH 08/17] Order placeholder runs --- src/zenml/zen_stores/sql_zen_store.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index da58ce785a0..4043220c284 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5598,6 +5598,12 @@ def _replace_placeholder_run( .where( PipelineRunSchema.status == ExecutionStatus.INITIALIZING.value ) + # In very rare cases, there can be multiple placeholder runs for + # the same deployment. By ordering by the orchestrator_run_id, we + # make sure that we use the placeholder run with the matching + # orchestrator_run_id if it exists, before falling back to the + # placeholder run without any orchestrator_run_id provided. + .order_by(col(PipelineRunSchema.orchestrator_run_id).nulls_last()) ).first() if not run_schema: From e81154d435bf40f7c9d0af0ef5993f5767fd1bc7 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 4 Jun 2025 18:17:15 +0200 Subject: [PATCH 09/17] Fix tests --- src/zenml/orchestrators/step_run_utils.py | 2 ++ .../kubernetes/orchestrators/test_manifest_utils.py | 7 +------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 8ed1291ca85..406d73c70c1 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -138,6 +138,8 @@ def populate_request( if cached_step_run := cache_utils.get_cached_step_run( cache_key=cache_key ): + # TODO: maybe also fetch the docstring/source code from the + # cached step run? request.inputs = { input_name: [artifact.id for artifact in artifacts] for input_name, artifacts in cached_step_run.inputs.items() diff --git a/tests/integration/integrations/kubernetes/orchestrators/test_manifest_utils.py b/tests/integration/integrations/kubernetes/orchestrators/test_manifest_utils.py index 9a4f8849b63..039eb54b8d8 100644 --- a/tests/integration/integrations/kubernetes/orchestrators/test_manifest_utils.py +++ b/tests/integration/integrations/kubernetes/orchestrators/test_manifest_utils.py @@ -34,8 +34,6 @@ def test_build_pod_manifest_metadata(): """Test that the metadata is correctly set in the manifest.""" manifest: V1Pod = build_pod_manifest( pod_name="test_name", - run_name="test_run", - pipeline_name="test_pipeline", image_name="test_image", command=["test", "command"], args=["test", "args"], @@ -43,6 +41,7 @@ def test_build_pod_manifest_metadata(): pod_settings=KubernetesPodSettings( annotations={"blupus_loves": "strawberries"}, ), + labels={"run": "test-run", "pipeline": "test-pipeline"}, ) assert isinstance(manifest, V1Pod) @@ -97,8 +96,6 @@ def test_build_pod_manifest_pod_settings( """Test that the pod settings are correctly set in the manifest.""" manifest: V1Pod = build_pod_manifest( pod_name="test_name", - run_name="test_run", - pipeline_name="test_pipeline", image_name="test_image", command=["test", "command"], args=["test", "args"], @@ -125,8 +122,6 @@ def test_build_cron_job_manifest_pod_settings( manifest: V1CronJob = build_cron_job_manifest( cron_expression="* * * * *", pod_name="test_name", - run_name="test_run", - pipeline_name="test_pipeline", image_name="test_image", command=["test", "command"], args=["test", "args"], From 6f3f45cccbc81355ceb4fef6e31ff42b6216812e Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 5 Jun 2025 09:19:32 +0200 Subject: [PATCH 10/17] Reuse docstring/source code from cache candidate --- src/zenml/orchestrators/step_run_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 406d73c70c1..0d69258dba8 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -138,8 +138,6 @@ def populate_request( if cached_step_run := cache_utils.get_cached_step_run( cache_key=cache_key ): - # TODO: maybe also fetch the docstring/source code from the - # cached step run? request.inputs = { input_name: [artifact.id for artifact in artifacts] for input_name, artifacts in cached_step_run.inputs.items() @@ -154,6 +152,15 @@ def populate_request( request.status = ExecutionStatus.CACHED request.end_time = request.start_time + # As a last resort, we try to reuse the docstring/source code + # from the cached step run. This is part of the cache key + # computation, so it must be identical to the one we would have + # computed ourselves. + if request.source_code is None: + request.source_code = cached_step_run.source_code + if request.docstring is None: + request.docstring = cached_step_run.docstring + def _get_docstring_and_source_code( self, invocation_id: str ) -> Tuple[Optional[str], Optional[str]]: From 6c8791ff3ab704510c9f8aeca49f9af45b87f24f Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 5 Jun 2025 09:29:14 +0200 Subject: [PATCH 11/17] Use more portable sorting --- src/zenml/zen_stores/sql_zen_store.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 8e7aaab0340..2446ece2e58 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5603,7 +5603,11 @@ def _replace_placeholder_run( # make sure that we use the placeholder run with the matching # orchestrator_run_id if it exists, before falling back to the # placeholder run without any orchestrator_run_id provided. - .order_by(col(PipelineRunSchema.orchestrator_run_id).nulls_last()) + # Note: This works because both SQLite and MySQL consider NULLs + # to be lower than any other value. If we add support for other + # databases (e.g. PostgreSQL, which considers NULLs to be greater + # than any other value), we need to potentially adjust this. + .order_by(desc(PipelineRunSchema.orchestrator_run_id)) ).first() if not run_schema: From c167a6993eb2779cf8a8d1b9566a53efbb7bce07 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Jun 2025 12:20:52 +0200 Subject: [PATCH 12/17] Formatting after merge --- .../orchestrators/kubernetes_orchestrator.py | 41 ++++++++----------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 541848e55a1..c02995d6979 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -400,30 +400,23 @@ def submit_pipeline( ) -> Optional[SubmissionResult]: """Submits a pipeline to the orchestrator. - This method should only submit the pipeline and not wait for it to - complete. If the orchestrator is configured to wait for the pipeline run - to complete, a function that waits for the pipeline run to complete can - be passed as part of the submission result. - - Args: - deployment: The pipeline deployment to submit. - stack: The stack the pipeline will run on. - environment: Environment variables to set in the orchestration - environment. These don't need to be set if running locally. - placeholder_run: An optional placeholder run for the deployment. - - Raises: - <<<<<<< HEAD - RuntimeError: If the Kubernetes orchestrator is not configured. - - Yields: - Metadata dictionary. - ======= - RuntimeError: If a schedule without cron expression is given. - - Returns: - Optional submission result. - >>>>>>> develop + This method should only submit the pipeline and not wait for it to + complete. If the orchestrator is configured to wait for the pipeline run + to complete, a function that waits for the pipeline run to complete can + be passed as part of the submission result. + + Args: + deployment: The pipeline deployment to submit. + stack: The stack the pipeline will run on. + environment: Environment variables to set in the orchestration + environment. These don't need to be set if running locally. + placeholder_run: An optional placeholder run for the deployment. + + Raises: + RuntimeError: If a schedule without cron expression is given. + + Returns: + Optional submission result. """ for step_name, step in deployment.step_configurations.items(): if self.requires_resources_in_orchestration_environment(step): From 233e33f0051c5bc241bf22dbcce0f020efa1ded1 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 24 Jun 2025 13:29:32 +0200 Subject: [PATCH 13/17] Linting --- .../kubernetes/orchestrators/kubernetes_orchestrator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index c02995d6979..613b82e518b 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -581,7 +581,7 @@ def submit_pipeline( startup_timeout=settings.pod_startup_timeout, ) - metadata = { + metadata: Dict[str, MetadataType] = { METADATA_ORCHESTRATOR_RUN_ID: pod_name, } From 7dcf4766afe5e9c597d4af25043a55b043e90ee7 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Jun 2025 13:13:07 +0200 Subject: [PATCH 14/17] Apply suggestions from code review Co-authored-by: Stefan Nica --- src/zenml/zen_stores/schemas/pipeline_run_schemas.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 2ac6cf05980..2b67678bea1 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -564,7 +564,7 @@ def update_placeholder( "placeholder run." ) - if request.status == ExecutionStatus.INITIALIZING: + if request.is_placeholder_request: raise ValueError( "Cannot replace a placeholder run with another placeholder run." ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index ea4db736cc3..7ee27ea6b22 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5970,7 +5970,7 @@ def get_or_create_run( pass if not pipeline_run.is_placeholder_request: - # Only run this is the request is not a placeholder run itself, + # Only run this if the request is not a placeholder run itself, # as we don't want to replace a placeholder run with another # placeholder run. try: From ec5e7a3caf4304702cd508262fd3fd3571099986 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Jun 2025 18:46:56 +0200 Subject: [PATCH 15/17] Add run name label --- .../kubernetes/orchestrators/kubernetes_orchestrator.py | 3 +++ .../orchestrators/kubernetes_orchestrator_entrypoint.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 613b82e518b..c3e008fbe1b 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -518,6 +518,9 @@ def submit_pipeline( orchestrator_pod_labels["run_id"] = kube_utils.sanitize_label( str(placeholder_run.id) ) + orchestrator_pod_labels["run_name"] = kube_utils.sanitize_label( + str(placeholder_run.name) + ) # Schedule as CRON job if CRON schedule is given. if deployment.schedule: diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 9bdf25eb99d..2136ca57893 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -164,6 +164,7 @@ def pre_step_run(step_name: str) -> bool: step_pod_labels = { "run_id": kube_utils.sanitize_label(str(pipeline_run.id)), + "run_name": kube_utils.sanitize_label(str(pipeline_run.name)), "pipeline": kube_utils.sanitize_label( deployment.pipeline_configuration.name ), From 50a10985c9f03b1a2961a0482969be5fb5c21b31 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Wed, 25 Jun 2025 18:55:55 +0200 Subject: [PATCH 16/17] Fetch step runs for failed nodes in batches --- .../kubernetes_orchestrator_entrypoint.py | 20 ++++---- src/zenml/orchestrators/input_utils.py | 41 +++------------- src/zenml/orchestrators/step_run_utils.py | 49 +++++++++++++++++++ 3 files changed, 65 insertions(+), 45 deletions(-) diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index 2136ca57893..a6f6b99594c 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -42,6 +42,7 @@ from zenml.orchestrators.dag_runner import NodeStatus, ThreadedDagRunner from zenml.orchestrators.step_run_utils import ( StepRunRequestFactory, + fetch_step_runs_by_names, publish_cached_step_run, ) from zenml.orchestrators.utils import ( @@ -291,6 +292,14 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: # Some steps may have failed because the pods could not be created. # We need to check for this and mark the step run as failed if so. pipeline_failed = False + failed_step_names = [ + step_name + for step_name, node_state in node_states.items() + if node_state == NodeStatus.FAILED + ] + step_runs = fetch_step_runs_by_names( + step_run_names=failed_step_names, pipeline_run=pipeline_run + ) for step_name, node_state in node_states.items(): if node_state != NodeStatus.FAILED: @@ -298,16 +307,7 @@ def finalize_run(node_states: Dict[str, NodeStatus]) -> None: pipeline_failed = True - # If steps failed for any reason, we need to mark the step run as - # failed, if it exists and it wasn't already in a final state. - step_runs = Client().list_run_steps( - size=1, - pipeline_run_id=pipeline_run.id, - name=step_name, - ) - - if step_runs: - step_run = step_runs[0] + if step_run := step_runs.get(step_name, None): # Try to update the step run status, if it exists and is in # a transient state. if step_run and step_run.status in { diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index eb634c285c7..9f8c5a0732b 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Utilities for inputs.""" -import json from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from uuid import UUID @@ -21,7 +20,7 @@ from zenml.config.step_configurations import Step from zenml.enums import StepRunInputArtifactType from zenml.exceptions import InputResolutionError -from zenml.utils import pagination_utils, string_utils +from zenml.utils import string_utils if TYPE_CHECKING: from zenml.models import PipelineRunResponse, StepRunResponse @@ -54,6 +53,7 @@ def resolve_step_inputs( """ from zenml.models import ArtifactVersionResponse from zenml.models.v2.core.step_run import StepRunInputResponse + from zenml.orchestrators.step_run_utils import fetch_step_runs_by_names step_runs = step_runs or {} @@ -65,40 +65,11 @@ def resolve_step_inputs( steps_to_fetch.difference_update(step_runs.keys()) if steps_to_fetch: - # The list of steps might be too big to fit in the default max URL - # length of 8KB supported by most servers. So we need to split it into - # smaller chunks. - steps_list = list(steps_to_fetch) - chunks = [] - current_chunk = [] - current_length = 0 - # stay under 6KB for good measure. - max_chunk_length = 6000 - - for step_name in steps_list: - current_chunk.append(step_name) - current_length += len(step_name) + 5 # 5 is for the JSON encoding - - if current_length > max_chunk_length: - chunks.append(current_chunk) - current_chunk = [] - current_length = 0 - - if current_chunk: - chunks.append(current_chunk) - - for chunk in chunks: - step_runs.update( - { - run_step.name: run_step - for run_step in pagination_utils.depaginate( - Client().list_run_steps, - pipeline_run_id=pipeline_run.id, - project=pipeline_run.project_id, - name="oneof:" + json.dumps(chunk), - ) - } + step_runs.update( + fetch_step_runs_by_names( + step_run_names=list(steps_to_fetch), pipeline_run=pipeline_run ) + ) input_artifacts: Dict[str, StepRunInputResponse] = {} for name, input_ in step.spec.inputs.items(): diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 0d69258dba8..9871932531f 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Utilities for creating step runs.""" +import json from typing import Dict, List, Optional, Set, Tuple, Union from zenml import Tag, add_tags @@ -32,6 +33,7 @@ ) from zenml.orchestrators import cache_utils, input_utils, utils from zenml.stack import Stack +from zenml.utils import pagination_utils from zenml.utils.time_utils import utc_now logger = get_logger(__name__) @@ -452,3 +454,50 @@ def publish_cached_step_run( ) return step_run + + +def fetch_step_runs_by_names( + step_run_names: List[str], pipeline_run: "PipelineRunResponse" +) -> Dict[str, "StepRunResponse"]: + """Fetch step runs by names. + + Args: + step_run_names: The names of the step runs to fetch. + pipeline_run: The pipeline run of the step runs. + + Returns: + A dictionary of step runs by name. + """ + step_runs = {} + + chunks = [] + current_chunk = [] + current_length = 0 + # stay under 6KB for good measure. + max_chunk_length = 6000 + + for step_name in step_run_names: + current_chunk.append(step_name) + current_length += len(step_name) + 5 # 5 is for the JSON encoding + + if current_length > max_chunk_length: + chunks.append(current_chunk) + current_chunk = [] + current_length = 0 + + if current_chunk: + chunks.append(current_chunk) + + for chunk in chunks: + step_runs.update( + { + run_step.name: run_step + for run_step in pagination_utils.depaginate( + Client().list_run_steps, + pipeline_run_id=pipeline_run.id, + project=pipeline_run.project_id, + name="oneof:" + json.dumps(chunk), + ) + } + ) + return step_runs From e4b819094f9bed293d080f7f52b5b9248135f94b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Thu, 26 Jun 2025 16:31:29 +0200 Subject: [PATCH 17/17] Reduce wait time, compute cache after acquiring lock --- src/zenml/orchestrators/dag_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/orchestrators/dag_runner.py b/src/zenml/orchestrators/dag_runner.py index a6a623eefb8..b41d171e8a7 100644 --- a/src/zenml/orchestrators/dag_runner.py +++ b/src/zenml/orchestrators/dag_runner.py @@ -161,7 +161,7 @@ def _prepare_node_run(self, node: str) -> None: break logger.debug(f"Waiting for {running_nodes} nodes to finish.") - time.sleep(10) + time.sleep(1) def _run_node(self, node: str) -> None: """Run a single node. @@ -171,14 +171,14 @@ def _run_node(self, node: str) -> None: Args: node: The node. """ + self._prepare_node_run(node) + if self.preparation_fn: run_required = self.preparation_fn(node) if not run_required: self._finish_node(node) return - self._prepare_node_run(node) - try: self.run_fn(node) self._finish_node(node)