Skip to content

Commit 9e77cbd

Browse files
committed
WIP
1 parent 9e017cd commit 9e77cbd

12 files changed

+226
-84
lines changed

src/zenml/config/compiler.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Class for compiling ZenML pipelines into a serializable format."""
1515

1616
import copy
17+
import os
1718
import string
1819
from typing import (
1920
TYPE_CHECKING,
@@ -36,6 +37,11 @@
3637
StepConfigurationUpdate,
3738
StepSpec,
3839
)
40+
from zenml.constants import (
41+
ENV_ZENML_ACTIVE_STACK_ID,
42+
ENV_ZENML_ACTIVE_WORKSPACE_ID,
43+
ENV_ZENML_STORE_PREFIX,
44+
)
3945
from zenml.environment import get_run_environment_dict
4046
from zenml.exceptions import StackValidationError
4147
from zenml.models import PipelineDeploymentBase
@@ -50,6 +56,8 @@
5056

5157
from zenml.logger import get_logger
5258

59+
ENVIRONMENT_VARIABLE_PREFIX = "__ZENML__"
60+
5361
logger = get_logger(__file__)
5462

5563

@@ -104,13 +112,20 @@ def compile(
104112
pipeline.configuration.substitutions,
105113
)
106114

115+
pipeline_environment = finalize_environment_variables(
116+
pipeline.configuration.environment
117+
)
107118
pipeline_settings = self._filter_and_validate_settings(
108119
settings=pipeline.configuration.settings,
109120
configuration_level=ConfigurationLevel.PIPELINE,
110121
stack=stack,
111122
)
112123
with pipeline.__suppress_configure_warnings__():
113-
pipeline.configure(settings=pipeline_settings, merge=False)
124+
pipeline.configure(
125+
environment=pipeline_environment,
126+
settings=pipeline_settings,
127+
merge=False,
128+
)
114129

115130
settings_to_passdown = {
116131
key: settings
@@ -121,6 +136,7 @@ def compile(
121136
steps = {
122137
invocation_id: self._compile_step_invocation(
123138
invocation=invocation,
139+
pipeline_environment=pipeline_environment,
124140
pipeline_settings=settings_to_passdown,
125141
pipeline_extra=pipeline.configuration.extra,
126142
stack=stack,
@@ -210,6 +226,7 @@ def _apply_run_configuration(
210226
enable_artifact_metadata=config.enable_artifact_metadata,
211227
enable_artifact_visualization=config.enable_artifact_visualization,
212228
enable_step_logs=config.enable_step_logs,
229+
environment=config.environment,
213230
settings=config.settings,
214231
tags=config.tags,
215232
extra=config.extra,
@@ -427,6 +444,7 @@ def _get_step_spec(
427444
def _compile_step_invocation(
428445
self,
429446
invocation: "StepInvocation",
447+
pipeline_environment: Optional[Dict[str, Any]],
430448
pipeline_settings: Dict[str, "BaseSettings"],
431449
pipeline_extra: Dict[str, Any],
432450
stack: "Stack",
@@ -438,7 +456,9 @@ def _compile_step_invocation(
438456
439457
Args:
440458
invocation: The step invocation to compile.
441-
pipeline_settings: settings configured on the
459+
pipeline_environment: Environment variables configured for the
460+
pipeline.
461+
pipeline_settings: Settings configured on the
442462
pipeline of the step.
443463
pipeline_extra: Extra values configured on the pipeline of the step.
444464
stack: The stack on which the pipeline will be run.
@@ -463,6 +483,9 @@ def _compile_step_invocation(
463483
step.configuration.settings, stack=stack
464484
)
465485
step_spec = self._get_step_spec(invocation=invocation)
486+
step_environment = finalize_environment_variables(
487+
step.configuration.environment
488+
)
466489
step_settings = self._filter_and_validate_settings(
467490
settings=step.configuration.settings,
468491
configuration_level=ConfigurationLevel.STEP,
@@ -473,13 +496,15 @@ def _compile_step_invocation(
473496
step_on_success_hook_source = step.configuration.success_hook_source
474497

475498
step.configure(
499+
environment=pipeline_environment,
476500
settings=pipeline_settings,
477501
extra=pipeline_extra,
478502
on_failure=pipeline_failure_hook_source,
479503
on_success=pipeline_success_hook_source,
480504
merge=False,
481505
)
482506
step.configure(
507+
environment=step_environment,
483508
settings=step_settings,
484509
extra=step_extra,
485510
on_failure=step_on_failure_hook_source,
@@ -635,3 +660,50 @@ def convert_component_shortcut_settings_keys(
635660
)
636661

637662
settings[key] = component_settings
663+
664+
665+
def finalize_environment_variables(
666+
environment: Dict[str, Any],
667+
) -> Dict[str, str]:
668+
"""Finalize the user environment variables.
669+
670+
This function adds all __ZENML__ prefixed environment variables from the
671+
local client environment to the explicit user-defined variables.
672+
673+
Args:
674+
environment: The explicit user-defined environment variables.
675+
676+
Returns:
677+
The finalized user environment variables.
678+
"""
679+
environment = {key: str(value) for key, value in environment.items()}
680+
681+
for key, value in os.environ.items():
682+
if key.startswith(ENVIRONMENT_VARIABLE_PREFIX):
683+
key_without_prefix = key[len(ENVIRONMENT_VARIABLE_PREFIX) :]
684+
685+
if (
686+
key_without_prefix in environment
687+
and value != environment[key_without_prefix]
688+
):
689+
logger.warning(
690+
"Got multiple values for environment variable `%s`.",
691+
key_without_prefix,
692+
)
693+
else:
694+
environment[key_without_prefix] = value
695+
696+
finalized_env = {}
697+
698+
for key, value in environment.items():
699+
if key.upper().startswith(ENV_ZENML_STORE_PREFIX) or key.upper() in [
700+
ENV_ZENML_ACTIVE_WORKSPACE_ID,
701+
ENV_ZENML_ACTIVE_STACK_ID,
702+
]:
703+
logger.warning(
704+
"Not allowed to set `%s` config environment variable.", key
705+
)
706+
continue
707+
finalized_env[key] = str(value)
708+
709+
return finalized_env

src/zenml/config/pipeline_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
4040
enable_artifact_metadata: Optional[bool] = None
4141
enable_artifact_visualization: Optional[bool] = None
4242
enable_step_logs: Optional[bool] = None
43+
environment: Dict[str, Any] = {}
4344
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4445
tags: Optional[List[str]] = None
4546
extra: Dict[str, Any] = {}

src/zenml/config/pipeline_run_configuration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class PipelineRunConfiguration(
4444
default=None, union_mode="left_to_right"
4545
)
4646
steps: Dict[str, StepConfigurationUpdate] = {}
47+
environment: Dict[str, Any] = {}
4748
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4849
tags: Optional[List[str]] = None
4950
extra: Dict[str, Any] = {}

src/zenml/config/step_configurations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class StepConfigurationUpdate(StrictBaseModel):
148148
step_operator: Optional[str] = None
149149
experiment_tracker: Optional[str] = None
150150
parameters: Dict[str, Any] = {}
151+
environment: Dict[str, Any] = {}
151152
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
152153
extra: Dict[str, Any] = {}
153154
failure_hook_source: Optional[SourceWithValidator] = None

src/zenml/orchestrators/cache_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def generate_cache_key(
9999
hash_.update(key.encode())
100100
hash_.update(str(value).encode())
101101

102+
# User-defined environment variables
103+
for key, value in sorted(step.config.environment.items()):
104+
hash_.update(key.encode())
105+
hash_.update(str(value).encode())
106+
102107
return hash_.hexdigest()
103108

104109

src/zenml/orchestrators/step_runner.py

Lines changed: 80 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@
5656
parse_return_type_annotations,
5757
resolve_type_annotation,
5858
)
59-
from zenml.utils import materializer_utils, source_utils, string_utils
59+
from zenml.utils import (
60+
env_utils,
61+
materializer_utils,
62+
source_utils,
63+
string_utils,
64+
)
6065
from zenml.utils.typing_utils import get_origin, is_union
6166

6267
if TYPE_CHECKING:
@@ -183,86 +188,90 @@ def run(
183188
)
184189

185190
step_failed = False
186-
try:
187-
return_values = step_instance.call_entrypoint(
188-
**function_params
189-
)
190-
except BaseException as step_exception: # noqa: E722
191-
step_failed = True
192-
if not handle_bool_env_var(
193-
ENV_ZENML_IGNORE_FAILURE_HOOK, False
194-
):
195-
if (
196-
failure_hook_source
197-
:= self.configuration.failure_hook_source
198-
):
199-
logger.info("Detected failure hook. Running...")
200-
self.load_and_run_hook(
201-
failure_hook_source,
202-
step_exception=step_exception,
203-
)
204-
raise
205-
finally:
191+
with env_utils.temporary_environment(step_run.config.environment):
206192
try:
207-
step_run_metadata = self._stack.get_step_run_metadata(
208-
info=step_run_info,
209-
)
210-
publish_step_run_metadata(
211-
step_run_id=step_run_info.step_run_id,
212-
step_run_metadata=step_run_metadata,
213-
)
214-
self._stack.cleanup_step_run(
215-
info=step_run_info, step_failed=step_failed
193+
return_values = step_instance.call_entrypoint(
194+
**function_params
216195
)
217-
if not step_failed:
196+
except BaseException as step_exception: # noqa: E722
197+
step_failed = True
198+
if not handle_bool_env_var(
199+
ENV_ZENML_IGNORE_FAILURE_HOOK, False
200+
):
218201
if (
219-
success_hook_source
220-
:= self.configuration.success_hook_source
202+
failure_hook_source
203+
:= self.configuration.failure_hook_source
221204
):
222-
logger.info("Detected success hook. Running...")
205+
logger.info("Detected failure hook. Running...")
223206
self.load_and_run_hook(
224-
success_hook_source,
225-
step_exception=None,
207+
failure_hook_source,
208+
step_exception=step_exception,
226209
)
227-
228-
# Store and publish the output artifacts of the step function.
229-
output_data = self._validate_outputs(
230-
return_values, output_annotations
231-
)
232-
artifact_metadata_enabled = is_setting_enabled(
233-
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
234-
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
210+
raise
211+
finally:
212+
try:
213+
step_run_metadata = self._stack.get_step_run_metadata(
214+
info=step_run_info,
235215
)
236-
artifact_visualization_enabled = is_setting_enabled(
237-
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
238-
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
216+
publish_step_run_metadata(
217+
step_run_id=step_run_info.step_run_id,
218+
step_run_metadata=step_run_metadata,
239219
)
240-
output_artifacts = self._store_output_artifacts(
241-
output_data=output_data,
242-
output_artifact_uris=output_artifact_uris,
243-
output_materializers=output_materializers,
244-
output_annotations=output_annotations,
245-
artifact_metadata_enabled=artifact_metadata_enabled,
246-
artifact_visualization_enabled=artifact_visualization_enabled,
220+
self._stack.cleanup_step_run(
221+
info=step_run_info, step_failed=step_failed
247222
)
248-
249-
if (
250-
model_version := step_run.model_version
251-
or pipeline_run.model_version
252-
):
253-
from zenml.orchestrators import step_run_utils
254-
255-
step_run_utils.link_output_artifacts_to_model_version(
256-
artifacts={
257-
k: [v] for k, v in output_artifacts.items()
258-
},
259-
model_version=model_version,
223+
if not step_failed:
224+
if (
225+
success_hook_source
226+
:= self.configuration.success_hook_source
227+
):
228+
logger.info(
229+
"Detected success hook. Running..."
230+
)
231+
self.load_and_run_hook(
232+
success_hook_source,
233+
step_exception=None,
234+
)
235+
236+
# Store and publish the output artifacts of the step function.
237+
output_data = self._validate_outputs(
238+
return_values, output_annotations
260239
)
261-
finally:
262-
step_context._cleanup_registry.execute_callbacks(
263-
raise_on_exception=False
264-
)
265-
StepContext._clear() # Remove the step context singleton
240+
artifact_metadata_enabled = is_setting_enabled(
241+
is_enabled_on_step=step_run_info.config.enable_artifact_metadata,
242+
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_metadata,
243+
)
244+
artifact_visualization_enabled = is_setting_enabled(
245+
is_enabled_on_step=step_run_info.config.enable_artifact_visualization,
246+
is_enabled_on_pipeline=step_run_info.pipeline.enable_artifact_visualization,
247+
)
248+
output_artifacts = self._store_output_artifacts(
249+
output_data=output_data,
250+
output_artifact_uris=output_artifact_uris,
251+
output_materializers=output_materializers,
252+
output_annotations=output_annotations,
253+
artifact_metadata_enabled=artifact_metadata_enabled,
254+
artifact_visualization_enabled=artifact_visualization_enabled,
255+
)
256+
257+
if (
258+
model_version := step_run.model_version
259+
or pipeline_run.model_version
260+
):
261+
from zenml.orchestrators import step_run_utils
262+
263+
step_run_utils.link_output_artifacts_to_model_version(
264+
artifacts={
265+
k: [v]
266+
for k, v in output_artifacts.items()
267+
},
268+
model_version=model_version,
269+
)
270+
finally:
271+
step_context._cleanup_registry.execute_callbacks(
272+
raise_on_exception=False
273+
)
274+
StepContext._clear() # Remove the step context singleton
266275

267276
# Update the status and output artifacts of the step run.
268277
output_artifact_ids = {

0 commit comments

Comments
 (0)