diff --git a/docs/link_checker.py b/docs/link_checker.py index 9af75181bac..6e5ecf9ef6d 100644 --- a/docs/link_checker.py +++ b/docs/link_checker.py @@ -59,7 +59,6 @@ import os import re import sys -import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple @@ -333,43 +332,68 @@ def validate_urls( """ if not urls: return {} - + results = {} # Count and report GitHub links that will be skipped in validation from urllib.parse import urlparse - github_urls = [url for url in urls if urlparse(url).hostname and urlparse(url).hostname.endswith("github.com")] - other_urls = [url for url in urls if urlparse(url).hostname and not urlparse(url).hostname.endswith("github.com")] - + + github_urls = [ + url + for url in urls + if urlparse(url).hostname + and urlparse(url).hostname.endswith("github.com") + ] + other_urls = [ + url + for url in urls + if urlparse(url).hostname + and not urlparse(url).hostname.endswith("github.com") + ] + print(f"Validating {len(urls)} links...") - print(f"Note: {len(github_urls)} GitHub links will be automatically marked as valid (skipping validation)") - + print( + f"Note: {len(github_urls)} GitHub links will be automatically marked as valid (skipping validation)" + ) + # Use moderate settings for non-GitHub URLs actual_max_workers = min(6, max_workers) - - print(f"Using {actual_max_workers} workers for remaining {len(other_urls)} links...") - + + print( + f"Using {actual_max_workers} workers for remaining {len(other_urls)} links..." + ) + with ThreadPoolExecutor(max_workers=actual_max_workers) as executor: future_to_url = {} - + # Submit all URLs (GitHub links will be auto-skipped in check_link_validity) for url in urls: - future_to_url[executor.submit(check_link_validity, url, timeout=15)] = url - + future_to_url[ + executor.submit(check_link_validity, url, timeout=15) + ] = url + # Process results for i, future in enumerate(as_completed(future_to_url), 1): url = future_to_url[future] try: _, is_valid, error_message, status_code = future.result() results[url] = (is_valid, error_message, status_code) - + if "github.com" in url: - print(f" Checked URL {i}/{len(urls)} [github.com]: ✓ Skipped (automatically marked valid)") + print( + f" Checked URL {i}/{len(urls)} [github.com]: ✓ Skipped (automatically marked valid)" + ) else: status = "✅ Valid" if is_valid else f"❌ {error_message}" - domain = url.split('/')[2] if '://' in url and '/' in url.split('://', 1)[1] else 'unknown' - print(f" Checked URL {i}/{len(urls)} [{domain}]: {status}") - + domain = ( + url.split("/")[2] + if "://" in url and "/" in url.split("://", 1)[1] + else "unknown" + ) + print( + f" Checked URL {i}/{len(urls)} [{domain}]: {status}" + ) + except Exception as e: results[url] = (False, str(e), None) print(f" Error checking URL {i}/{len(urls)}: {e}") diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index c62214e2dd5..3703a0a996a 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -292,6 +292,10 @@ def _bypass() -> None: artifacts=step_run.outputs, model_version=model_version, ) + step_run_utils.cascade_tags_for_output_artifacts( + artifacts=step_run.outputs, + tags=pipeline_run.config.tags, + ) except: # noqa: E722 logger.error(f"Pipeline run `{pipeline_run.name}` failed.") diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index 9de52702906..d2a4bda9322 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -13,8 +13,9 @@ # permissions and limitations under the License. """Utilities for creating step runs.""" -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union +from zenml import Tag, add_tags from zenml.client import Client from zenml.config.step_configurations import Step from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH @@ -333,6 +334,11 @@ def create_cached_step_runs( model_version=model_version, ) + cascade_tags_for_output_artifacts( + artifacts=step_run.outputs, + tags=pipeline_run.config.tags, + ) + logger.info("Using cached version of step `%s`.", invocation_id) cached_invocations.add(invocation_id) @@ -382,3 +388,27 @@ def link_output_artifacts_to_model_version( artifact_version=output_artifact, model_version=model_version, ) + + +def cascade_tags_for_output_artifacts( + artifacts: Dict[str, List[ArtifactVersionResponse]], + tags: Optional[List[Union[str, Tag]]] = None, +) -> None: + """Tag the outputs of a step run. + + Args: + artifacts: The step output artifacts. + tags: The tags to add to the artifacts. + """ + if tags is None: + return + + for output_artifacts in artifacts.values(): + for output_artifact in output_artifacts: + cascade_tags = [ + t for t in tags if isinstance(t, Tag) and t.cascade + ] + add_tags( + tags=[t.name for t in cascade_tags], + artifact_version_id=output_artifact.id, + ) diff --git a/src/zenml/utils/tag_utils.py b/src/zenml/utils/tag_utils.py index 791c8cb498f..54d1bd1fada 100644 --- a/src/zenml/utils/tag_utils.py +++ b/src/zenml/utils/tag_utils.py @@ -348,10 +348,10 @@ def add_tags( if isinstance(tag, Tag): tag_model = client.get_tag(tag.name) - if tag.exclusive != tag_model.exclusive: + if bool(tag.exclusive) != tag_model.exclusive: raise ValueError( - f"The tag `{tag.name}` is an " - f"{'exclusive' if tag_model.exclusive else 'non-exclusive'} " + f"The tag `{tag.name}` is " + f"{'an exclusive' if tag_model.exclusive else 'a non-exclusive'} " "tag. Please update it before attaching it to a resource." ) if tag.cascade is not None: diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index bb1191429e1..be0af8edb5e 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -11354,13 +11354,10 @@ def _attach_tags_to_resources( except EntityExistsError: if isinstance(tag, tag_utils.Tag): tag_schema = self._get_tag_schema(tag.name, session) - if ( - tag.exclusive is not None - and tag.exclusive != tag_schema.exclusive - ): + if bool(tag.exclusive) != tag_schema.exclusive: raise ValueError( - f"Tag `{tag_schema.name}` has been defined as a " - f"{'exclusive' if tag_schema.exclusive else 'non-exclusive'} " + f"Tag `{tag_schema.name}` has been defined as " + f"{'an exclusive' if tag_schema.exclusive else 'a non-exclusive'} " "tag. Please update it before attaching it to resources." ) else: diff --git a/tests/integration/functional/utils/test_tag_utils.py b/tests/integration/functional/utils/test_tag_utils.py index 2f0d51da8d9..2b12e1d6788 100644 --- a/tests/integration/functional/utils/test_tag_utils.py +++ b/tests/integration/functional/utils/test_tag_utils.py @@ -13,11 +13,14 @@ # permissions and limitations under the License. +import os from typing import Annotated, Tuple import pytest from zenml import ArtifactConfig, Tag, add_tags, pipeline, remove_tags, step +from zenml.constants import ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING +from zenml.enums import ExecutionStatus @step @@ -139,3 +142,91 @@ def test_tag_utils(clean_client): clean_client.update_tag( tag_name_or_id=non_exclusive_tag.id, exclusive=True ) + + +def test_cascade_tags_for_output_artifacts_of_cached_pipeline_run( + clean_client, +): + """Test that the cascade tags are added to the output artifacts of a cached step.""" + # Run the pipeline once without caching + pipeline_to_tag() + + pipeline_runs = clean_client.list_pipeline_runs(sort_by="created") + assert len(pipeline_runs.items) == 1 + assert ( + pipeline_runs.items[0].steps["step_single_output"].status + == ExecutionStatus.COMPLETED + ) + assert "cascade_tag" in [ + t.name + for t in pipeline_runs.items[0] + .steps["step_single_output"] + .outputs["single"][0] + .tags + ] + + # Run it once again with caching + pipeline_to_tag.configure(enable_cache=True) + pipeline_to_tag() + pipeline_runs = clean_client.list_pipeline_runs(sort_by="created") + assert len(pipeline_runs.items) == 2 + assert ( + pipeline_runs.items[1].steps["step_single_output"].status + == ExecutionStatus.CACHED + ) + + # Run it once again with caching and a new cascade tag + pipeline_to_tag.configure( + tags=[Tag(name="second_cascade_tag", cascade=True)] + ) + pipeline_to_tag() + pipeline_runs = clean_client.list_pipeline_runs(sort_by="created") + assert len(pipeline_runs.items) == 3 + assert ( + pipeline_runs.items[2].steps["step_single_output"].status + == ExecutionStatus.CACHED + ) + + assert "second_cascade_tag" in [ + t.name + for t in pipeline_runs.items[0] + .steps["step_single_output"] + .outputs["single"][0] + .tags + ] + assert "second_cascade_tag" in [ + t.name + for t in pipeline_runs.items[2] + .steps["step_single_output"] + .outputs["single"][0] + .tags + ] + + # Run it once again with caching (preventing client side caching) and a new cascade tag + os.environ[ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING] = "true" + pipeline_to_tag.configure( + tags=[Tag(name="third_cascade_tag", cascade=True)] + ) + pipeline_to_tag() + + pipeline_runs = clean_client.list_pipeline_runs(sort_by="created") + assert len(pipeline_runs.items) == 4 + assert ( + pipeline_runs.items[3].steps["step_single_output"].status + == ExecutionStatus.CACHED + ) + + assert "third_cascade_tag" in [ + t.name + for t in pipeline_runs.items[0] + .steps["step_single_output"] + .outputs["single"][0] + .tags + ] + assert "third_cascade_tag" in [ + t.name + for t in pipeline_runs.items[3] + .steps["step_single_output"] + .outputs["single"][0] + .tags + ]