Utilizing cascading tags for cached step runs #3655

Open · wants to merge 1 commit into base: develop
60 changes: 42 additions & 18 deletions docs/link_checker.py
@@ -59,7 +59,6 @@
import os
import re
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

@@ -333,43 +332,68 @@ def validate_urls(
"""
if not urls:
return {}

results = {}

# Count and report GitHub links that will be skipped in validation
from urllib.parse import urlparse
github_urls = [url for url in urls if urlparse(url).hostname and urlparse(url).hostname.endswith("github.com")]
other_urls = [url for url in urls if urlparse(url).hostname and not urlparse(url).hostname.endswith("github.com")]


github_urls = [
url
for url in urls
if urlparse(url).hostname
and urlparse(url).hostname.endswith("github.com")
]
other_urls = [
url
for url in urls
if urlparse(url).hostname
and not urlparse(url).hostname.endswith("github.com")
]

print(f"Validating {len(urls)} links...")
print(f"Note: {len(github_urls)} GitHub links will be automatically marked as valid (skipping validation)")

print(
f"Note: {len(github_urls)} GitHub links will be automatically marked as valid (skipping validation)"
)

# Use moderate settings for non-GitHub URLs
actual_max_workers = min(6, max_workers)

print(f"Using {actual_max_workers} workers for remaining {len(other_urls)} links...")


print(
f"Using {actual_max_workers} workers for remaining {len(other_urls)} links..."
)

with ThreadPoolExecutor(max_workers=actual_max_workers) as executor:
future_to_url = {}

# Submit all URLs (GitHub links will be auto-skipped in check_link_validity)
for url in urls:
future_to_url[executor.submit(check_link_validity, url, timeout=15)] = url

future_to_url[
executor.submit(check_link_validity, url, timeout=15)
] = url

# Process results
for i, future in enumerate(as_completed(future_to_url), 1):
url = future_to_url[future]
try:
_, is_valid, error_message, status_code = future.result()
results[url] = (is_valid, error_message, status_code)

if "github.com" in url:
print(f" Checked URL {i}/{len(urls)} [github.com]: ✓ Skipped (automatically marked valid)")
print(
f" Checked URL {i}/{len(urls)} [github.com]: ✓ Skipped (automatically marked valid)"
)
else:
status = "✅ Valid" if is_valid else f"❌ {error_message}"
domain = url.split('/')[2] if '://' in url and '/' in url.split('://', 1)[1] else 'unknown'
print(f" Checked URL {i}/{len(urls)} [{domain}]: {status}")

domain = (
url.split("/")[2]
if "://" in url and "/" in url.split("://", 1)[1]
else "unknown"
)
print(
f" Checked URL {i}/{len(urls)} [{domain}]: {status}"
)

except Exception as e:
results[url] = (False, str(e), None)
print(f" Error checking URL {i}/{len(urls)}: {e}")
4 changes: 4 additions & 0 deletions src/zenml/orchestrators/step_launcher.py
@@ -292,6 +292,10 @@ def _bypass() -> None:
artifacts=step_run.outputs,
model_version=model_version,
)
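# Cached runs reuse previously produced artifact versions; re-apply cascade
# tags from the run configuration so they also reach those outputs.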
step_run_utils.cascade_tags_for_output_artifacts(
artifacts=step_run.outputs,
tags=pipeline_run.config.tags,
)

except: # noqa: E722
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
32 changes: 31 additions & 1 deletion src/zenml/orchestrators/step_run_utils.py
@@ -13,8 +13,9 @@
# permissions and limitations under the License.
"""Utilities for creating step runs."""

from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union

from zenml import Tag, add_tags
from zenml.client import Client
from zenml.config.step_configurations import Step
from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH
@@ -333,6 +334,11 @@ def create_cached_step_runs(
model_version=model_version,
)

cascade_tags_for_output_artifacts(
artifacts=step_run.outputs,
tags=pipeline_run.config.tags,
)

logger.info("Using cached version of step `%s`.", invocation_id)
cached_invocations.add(invocation_id)

@@ -382,3 +388,27 @@ def link_output_artifacts_to_model_version(
artifact_version=output_artifact,
model_version=model_version,
)


def cascade_tags_for_output_artifacts(
    artifacts: Dict[str, List[ArtifactVersionResponse]],
    tags: Optional[List[Union[str, Tag]]] = None,
) -> None:
    """Apply cascade tags to the output artifacts of a step run.

    Args:
        artifacts: The step output artifacts.
        tags: The tags of the pipeline run; only `Tag` entries with
            `cascade=True` are applied to the artifacts.
    """
    if tags is None:
        return

    # Cascade tags do not depend on the artifact, so compute them once.
    cascade_tags = [t for t in tags if isinstance(t, Tag) and t.cascade]
    if not cascade_tags:
        return

    for output_artifacts in artifacts.values():
        for output_artifact in output_artifacts:
            add_tags(
                tags=[t.name for t in cascade_tags],
                artifact_version_id=output_artifact.id,
            )
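
For orientation, here is a minimal, hedged sketch of the user-facing behavior this utility supports, modeled on the integration test further down; the step and pipeline names are illustrative and not part of the PR.

from typing import Annotated

from zenml import Tag, pipeline, step


@step
def producer() -> Annotated[str, "single"]:
    """Produce a single output artifact named "single"."""
    return "hello"


@pipeline
def tagged_pipeline() -> None:
    producer()


if __name__ == "__main__":
    # First run: the cascade tag is applied to the "single" output artifact.
    tagged_pipeline.configure(tags=[Tag(name="cascade_tag", cascade=True)])
    tagged_pipeline()

    # Second run with caching enabled: the step is cached, and
    # cascade_tags_for_output_artifacts() now re-applies the cascade tag to
    # the cached run's output artifacts as well.
    tagged_pipeline.configure(enable_cache=True)
    tagged_pipeline()
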
6 changes: 3 additions & 3 deletions src/zenml/utils/tag_utils.py
@@ -348,10 +348,10 @@ def add_tags(
if isinstance(tag, Tag):
tag_model = client.get_tag(tag.name)

if tag.exclusive != tag_model.exclusive:
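# Coerce with bool(): Tag.exclusive is optional and may be unset (None), in
# which case it should compare equal to a non-exclusive stored tag.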
if bool(tag.exclusive) != tag_model.exclusive:
raise ValueError(
f"The tag `{tag.name}` is an "
f"{'exclusive' if tag_model.exclusive else 'non-exclusive'} "
f"The tag `{tag.name}` is "
f"{'an exclusive' if tag_model.exclusive else 'a non-exclusive'} "
"tag. Please update it before attaching it to a resource."
)
if tag.cascade is not None:
9 changes: 3 additions & 6 deletions src/zenml/zen_stores/sql_zen_store.py
@@ -11354,13 +11354,10 @@ def _attach_tags_to_resources(
except EntityExistsError:
if isinstance(tag, tag_utils.Tag):
tag_schema = self._get_tag_schema(tag.name, session)
if (
tag.exclusive is not None
and tag.exclusive != tag_schema.exclusive
):
if bool(tag.exclusive) != tag_schema.exclusive:
raise ValueError(
f"Tag `{tag_schema.name}` has been defined as a "
f"{'exclusive' if tag_schema.exclusive else 'non-exclusive'} "
f"Tag `{tag_schema.name}` has been defined as "
f"{'an exclusive' if tag_schema.exclusive else 'a non-exclusive'} "
"tag. Please update it before attaching it to resources."
)
else:
91 changes: 91 additions & 0 deletions tests/integration/functional/utils/test_tag_utils.py
@@ -13,11 +13,14 @@
# permissions and limitations under the License.


import os
from typing import Annotated, Tuple

import pytest

from zenml import ArtifactConfig, Tag, add_tags, pipeline, remove_tags, step
from zenml.constants import ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING
from zenml.enums import ExecutionStatus


@step
@@ -139,3 +142,91 @@ def test_tag_utils(clean_client):
clean_client.update_tag(
tag_name_or_id=non_exclusive_tag.id, exclusive=True
)


def test_cascade_tags_for_output_artifacts_of_cached_pipeline_run(
clean_client,
):
"""Test that the cascade tags are added to the output artifacts of a cached step."""
# Run the pipeline once without caching
pipeline_to_tag()

pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
assert len(pipeline_runs.items) == 1
assert (
pipeline_runs.items[0].steps["step_single_output"].status
== ExecutionStatus.COMPLETED
)
assert "cascade_tag" in [
t.name
for t in pipeline_runs.items[0]
.steps["step_single_output"]
.outputs["single"][0]
.tags
]

# Run it once again with caching
pipeline_to_tag.configure(enable_cache=True)
pipeline_to_tag()
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
assert len(pipeline_runs.items) == 2
assert (
pipeline_runs.items[1].steps["step_single_output"].status
== ExecutionStatus.CACHED
)

# Run it once again with caching and a new cascade tag
pipeline_to_tag.configure(
tags=[Tag(name="second_cascade_tag", cascade=True)]
)
pipeline_to_tag()
pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
assert len(pipeline_runs.items) == 3
assert (
pipeline_runs.items[2].steps["step_single_output"].status
== ExecutionStatus.CACHED
)

assert "second_cascade_tag" in [
t.name
for t in pipeline_runs.items[0]
.steps["step_single_output"]
.outputs["single"][0]
.tags
]
assert "second_cascade_tag" in [
t.name
for t in pipeline_runs.items[2]
.steps["step_single_output"]
.outputs["single"][0]
.tags
]

# Run it once again with caching (preventing client side caching) and a new cascade tag
os.environ[ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING] = "true"
pipeline_to_tag.configure(
tags=[Tag(name="third_cascade_tag", cascade=True)]
)
pipeline_to_tag()

pipeline_runs = clean_client.list_pipeline_runs(sort_by="created")
assert len(pipeline_runs.items) == 4
assert (
pipeline_runs.items[3].steps["step_single_output"].status
== ExecutionStatus.CACHED
)

assert "third_cascade_tag" in [
t.name
for t in pipeline_runs.items[0]
.steps["step_single_output"]
.outputs["single"][0]
.tags
]
assert "third_cascade_tag" in [
t.name
for t in pipeline_runs.items[3]
.steps["step_single_output"]
.outputs["single"][0]
.tags
]