Skip to content

Commit 20f8324

Browse files
authored
Merge pull request #3716 from zenml-io/feature/PRD-940-PRD-1044-stopping-pipelines
Ability to stop pipelines early + Kubernetes Orchestrator implementation
2 parents 32fe99a + bc85e53 commit 20f8324

File tree

22 files changed

+650
-77
lines changed

22 files changed

+650
-77
lines changed

src/zenml/cli/pipeline.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
ScheduleFilter,
3535
)
3636
from zenml.pipelines.pipeline_definition import Pipeline
37-
from zenml.utils import source_utils, uuid_utils
37+
from zenml.utils import run_utils, source_utils, uuid_utils
3838
from zenml.utils.yaml_utils import write_yaml
3939

4040
logger = get_logger(__name__)
@@ -511,6 +511,59 @@ def list_pipeline_runs(**kwargs: Any) -> None:
511511
cli_utils.print_page_info(pipeline_runs)
512512

513513

514+
@runs.command("stop")
@click.argument("run_name_or_id", type=str, required=True)
@click.option(
    "--graceful",
    "-g",
    is_flag=True,
    default=False,
    help="Use graceful shutdown (default is False).",
)
@click.option(
    "--yes",
    "-y",
    is_flag=True,
    default=False,
    help="Don't ask for confirmation.",
)
def stop_pipeline_run(
    run_name_or_id: str,
    graceful: bool = False,
    yes: bool = False,
) -> None:
    """Stop a running pipeline.

    Args:
        run_name_or_id: The name or ID of the pipeline run to stop.
        graceful: If True, uses graceful shutdown. If False, forces immediate termination.
        yes: If set, don't ask for confirmation.
    """
    # NOTE: the docstring above is left untouched on purpose; click surfaces
    # a command's docstring as its help text.

    # Unless `--yes` was passed, the user must confirm before anything
    # is looked up or stopped.
    if not yes:
        verb = "gracefully stop" if graceful else "force stop"
        prompt = (
            f"Are you sure you want to {verb} pipeline run `{run_name_or_id}`?"
        )
        if not cli_utils.confirmation(prompt):
            cli_utils.declare("Not stopping the pipeline run.")
            return

    # Resolve the run and ask for it to be stopped. The success message is
    # emitted inside the `try` so that any failure while reporting is routed
    # through the same error handling as the stop request itself.
    try:
        run = Client().get_pipeline_run(name_id_or_prefix=run_name_or_id)
        run_utils.stop_run(run=run, graceful=graceful)
        outcome = "Gracefully stopped" if graceful else "Force stopped"
        cli_utils.declare(f"{outcome} pipeline run '{run.name}'.")
    except NotImplementedError:
        # The orchestrator has no implementation for the requested stop mode.
        manner = "gracefully" if graceful else "forcefully"
        cli_utils.error(
            "The orchestrator used for this pipeline run does not support "
            f"{manner} stopping runs."
        )
    except Exception as exc:
        cli_utils.error(f"Failed to stop pipeline run: {exc}")
565+
566+
514567
@runs.command("delete")
515568
@click.argument("run_name_or_id", type=str, required=True)
516569
@click.option(

src/zenml/cli/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,6 +2214,8 @@ def get_execution_status_emoji(status: "ExecutionStatus") -> str:
22142214
return ":white_check_mark:"
22152215
if status == ExecutionStatus.CACHED:
22162216
return ":package:"
2217+
if status == ExecutionStatus.STOPPED or status == ExecutionStatus.STOPPING:
2218+
return ":stop_sign:"
22172219
raise RuntimeError(f"Unknown status: {status}")
22182220

22192221

src/zenml/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
416416
STATUS = "/status"
417417
STEP_CONFIGURATION = "/step-configuration"
418418
STEPS = "/steps"
419+
STOP = "/stop"
419420
TAGS = "/tags"
420421
TAG_RESOURCES = "/tag_resources"
421422
TRIGGERS = "/triggers"

src/zenml/enums.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,28 @@ class ZenMLServiceType(StrEnum):
7171

7272

7373
class ExecutionStatus(StrEnum):
    """Execution status that a step or a pipeline run can be in."""

    INITIALIZING = "initializing"
    FAILED = "failed"
    COMPLETED = "completed"
    RUNNING = "running"
    CACHED = "cached"
    STOPPED = "stopped"
    STOPPING = "stopping"

    @property
    def is_finished(self) -> bool:
        """Tells whether this status is one of the finished states.

        Returns:
            True for `FAILED`, `COMPLETED`, `CACHED` and `STOPPED`; False
            for every other status (including `STOPPING`).
        """
        finished_states = (
            ExecutionStatus.FAILED,
            ExecutionStatus.COMPLETED,
            ExecutionStatus.CACHED,
            ExecutionStatus.STOPPED,
        )
        return self in finished_states
9497

9598

src/zenml/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,14 @@ class IllegalOperationError(ZenMLBaseException):
122122
"""Raised when an illegal operation is attempted."""
123123

124124

125+
class RunStoppedException(ZenMLBaseException):
    """Error raised once the user has stopped a ZenML pipeline run."""
127+
128+
129+
class RunInterruptedException(ZenMLBaseException):
    """Error raised if a ZenML step is interrupted and the cause is unknown."""
131+
132+
125133
class MethodNotAllowedError(ZenMLBaseException):
126134
"""Raised when the server does not allow a request method."""
127135

src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -853,12 +853,16 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
853853
)["PipelineExecutionStatus"]
854854

855855
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
856-
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/PipelineState
857-
if status in ["Executing", "Stopping"]:
856+
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribePipelineExecution.html
857+
if status == "Executing":
858858
return ExecutionStatus.RUNNING
859-
elif status in ["Stopped", "Failed"]:
859+
elif status == "Stopping":
860+
return ExecutionStatus.STOPPING
861+
elif status == "Stopped":
862+
return ExecutionStatus.STOPPED
863+
elif status == "Failed":
860864
return ExecutionStatus.FAILED
861-
elif status in ["Succeeded"]:
865+
elif status == "Succeeded":
862866
return ExecutionStatus.COMPLETED
863867
else:
864868
raise ValueError("Unknown status for the pipeline execution.")

src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -515,14 +515,16 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
515515
return ExecutionStatus.INITIALIZING
516516
elif status in ["Running", "Finalizing"]:
517517
return ExecutionStatus.RUNNING
518+
elif status == "CancelRequested":
519+
return ExecutionStatus.STOPPING
520+
elif status == "Canceled":
521+
return ExecutionStatus.STOPPED
518522
elif status in [
519-
"CancelRequested",
520523
"Failed",
521-
"Canceled",
522524
"NotResponding",
523525
]:
524526
return ExecutionStatus.FAILED
525-
elif status in ["Completed"]:
527+
elif status == "Completed":
526528
return ExecutionStatus.COMPLETED
527529
else:
528530
raise ValueError("Unknown status for the pipeline job.")

src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
942942

943943
# Map the potential outputs to ZenML ExecutionStatus. Potential values:
944944
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/PipelineState
945-
if status in [PipelineState.PIPELINE_STATE_UNSPECIFIED]:
945+
if status == PipelineState.PIPELINE_STATE_UNSPECIFIED:
946946
return run.status
947947
elif status in [
948948
PipelineState.PIPELINE_STATE_QUEUED,
@@ -954,14 +954,13 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus:
954954
PipelineState.PIPELINE_STATE_PAUSED,
955955
]:
956956
return ExecutionStatus.RUNNING
957-
elif status in [PipelineState.PIPELINE_STATE_SUCCEEDED]:
957+
elif status == PipelineState.PIPELINE_STATE_SUCCEEDED:
958958
return ExecutionStatus.COMPLETED
959-
960-
elif status in [
961-
PipelineState.PIPELINE_STATE_FAILED,
962-
PipelineState.PIPELINE_STATE_CANCELLING,
963-
PipelineState.PIPELINE_STATE_CANCELLED,
964-
]:
959+
elif status == PipelineState.PIPELINE_STATE_CANCELLING:
960+
return ExecutionStatus.STOPPING
961+
elif status == PipelineState.PIPELINE_STATE_CANCELLED:
962+
return ExecutionStatus.STOPPED
963+
elif status == PipelineState.PIPELINE_STATE_FAILED:
965964
return ExecutionStatus.FAILED
966965
else:
967966
raise ValueError("Unknown status for the pipeline job.")

src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class KubernetesOrchestratorSettings(BaseSettings):
6969
scheduling a pipeline.
7070
prevent_orchestrator_pod_caching: If `True`, the orchestrator pod will
7171
not try to compute cached steps before starting the step pods.
72+
pod_stop_grace_period: When stopping a pipeline run, the number of
73+
seconds to wait for a step pod to shut down gracefully.
7274
"""
7375

7476
synchronous: bool = True
@@ -88,6 +90,7 @@ class KubernetesOrchestratorSettings(BaseSettings):
8890
failed_jobs_history_limit: Optional[NonNegativeInt] = None
8991
ttl_seconds_after_finished: Optional[NonNegativeInt] = None
9092
prevent_orchestrator_pod_caching: bool = False
93+
pod_stop_grace_period: PositiveInt = 30
9194

9295

9396
class KubernetesOrchestratorConfig(

src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ def submit_pipeline(
545545
successful_jobs_history_limit=settings.successful_jobs_history_limit,
546546
failed_jobs_history_limit=settings.failed_jobs_history_limit,
547547
ttl_seconds_after_finished=settings.ttl_seconds_after_finished,
548+
termination_grace_period_seconds=settings.pod_stop_grace_period,
548549
labels=orchestrator_pod_labels,
549550
)
550551

@@ -570,6 +571,7 @@ def submit_pipeline(
570571
env=environment,
571572
labels=orchestrator_pod_labels,
572573
mount_local_stores=self.config.is_local,
574+
termination_grace_period_seconds=settings.pod_stop_grace_period,
573575
)
574576

575577
kube_utils.create_and_wait_for_pod_to_start(
@@ -663,6 +665,92 @@ def get_orchestrator_run_id(self) -> str:
663665
f"{ENV_ZENML_KUBERNETES_RUN_ID}."
664666
)
665667

668+
def _stop_run(
    self, run: "PipelineRunResponse", graceful: bool = True
) -> None:
    """Stops a specific pipeline run by terminating step pods.

    For a graceful stop nothing is deleted here; the method only logs and
    returns. Otherwise every pod carrying the `run_id` label of this run
    whose phase is `Running` or `Pending` is deleted. Deletion is attempted
    for all matching pods even if some deletions fail; failures are
    collected and reported together at the end.

    Args:
        run: The run that was executed by this orchestrator.
        graceful: If True, does nothing (lets the orchestrator and steps
            finish naturally). If False, stops all running step pods.

    Raises:
        RuntimeError: If listing the pods fails, or if deleting at least
            one of the pods fails.
    """
    # If graceful, do nothing and let the orchestrator handle the stop
    # naturally.
    if graceful:
        logger.info(
            "Graceful stop requested - the orchestrator pod will handle "
            "stopping naturally"
        )
        return

    pods_stopped = []
    errors = []

    # Select every pod labelled with the ID of this run.
    label_selector = f"run_id={kube_utils.sanitize_label(str(run.id))}"
    try:
        pods = self._k8s_core_api.list_namespaced_pod(
            namespace=self.config.kubernetes_namespace,
            label_selector=label_selector,
        )
    except Exception as e:
        # Chain the original exception so the underlying API error and its
        # traceback are not lost.
        raise RuntimeError(
            f"Failed to list step pods with run ID {run.id}: {e}"
        ) from e

    for pod in pods.items:
        # Only pods that are still running or pending need to be stopped.
        if pod.status.phase not in ["Running", "Pending"]:
            logger.debug(
                f"Skipping pod {pod.metadata.name} with status {pod.status.phase}"
            )
            continue

        # No explicit grace period is passed to the delete call.
        # NOTE(review): presumably the termination grace period from the
        # pod spec applies - confirm.
        try:
            self._k8s_core_api.delete_namespaced_pod(
                name=pod.metadata.name,
                namespace=self.config.kubernetes_namespace,
            )
            pods_stopped.append(f"step pod: {pod.metadata.name}")
            logger.debug(
                f"Successfully initiated graceful stop of step pod: {pod.metadata.name}"
            )
        except Exception as e:
            # Best effort: remember the failure and keep going so the
            # remaining pods still get deleted.
            error_msg = f"Failed to stop step pod {pod.metadata.name}: {e}"
            logger.warning(error_msg)
            errors.append(error_msg)

    # Summary logging. The settings are only needed for this message, so
    # they are only resolved when there is something to report.
    if pods_stopped:
        settings = cast(
            KubernetesOrchestratorSettings, self.get_settings(run)
        )
        grace_period_seconds = settings.pod_stop_grace_period
        logger.debug(
            f"Successfully initiated graceful termination of: {', '.join(pods_stopped)}. "
            f"Pods will terminate within {grace_period_seconds} seconds."
        )

    if errors:
        error_summary = "; ".join(errors)
        # Any failed deletion is an error; the message tells whether
        # nothing or only part of the run could be stopped.
        if pods_stopped:
            prefix = "Partial stop operation completed with errors"
        else:
            prefix = "Failed to stop pipeline run"
        raise RuntimeError(f"{prefix}: {error_summary}")

    # Reaching this point means no deletion failed; if nothing was stopped
    # either, there simply were no matching running/pending pods.
    if not pods_stopped:
        logger.info(
            f"No running step pods found for pipeline run with ID: {run.id}"
        )
753+
666754
def get_pipeline_run_metadata(
667755
self, run_id: UUID
668756
) -> Dict[str, "MetadataType"]:

0 commit comments

Comments
 (0)