Skip to content

Commit d005dc3

Browse files
committed
More changes
1 parent a88e52c commit d005dc3

File tree

10 files changed

+164
-64
lines changed

10 files changed

+164
-64
lines changed

src/zenml/client.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,12 +1211,7 @@ def create_stack(
12111211
)
12121212
stack_components[c_type] = [component.id]
12131213

1214-
secret_ids = []
1215-
for secret in secrets or []:
1216-
if isinstance(secret, UUID):
1217-
secret_ids.append(secret)
1218-
else:
1219-
secret_ids.append(self.get_secret(secret).id)
1214+
secret_ids = [self.get_secret(secret).id for secret in secrets or []]
12201215

12211216
stack = StackRequest(
12221217
name=name,
@@ -1323,6 +1318,8 @@ def update_stack(
13231318
component_updates: Optional[
13241319
Dict[StackComponentType, List[Union[UUID, str]]]
13251320
] = None,
1321+
add_secrets: Optional[List[Union[UUID, str]]] = None,
1322+
remove_secrets: Optional[List[Union[UUID, str]]] = None,
13261323
) -> StackResponse:
13271324
"""Updates a stack and its components.
13281325
@@ -1334,6 +1331,8 @@ def update_stack(
13341331
description: the new description of the stack.
13351332
component_updates: dictionary which maps stack component types to
13361333
lists of new stack component names or ids.
1334+
add_secrets: The secrets to add to the stack.
1335+
remove_secrets: The secrets to remove from the stack.
13371336
13381337
Returns:
13391338
The model of the updated stack.
@@ -1391,6 +1390,16 @@ def update_stack(
13911390
}
13921391
update_model.labels = existing_labels
13931392

1393+
if add_secrets:
1394+
update_model.add_secrets = [
1395+
self.get_secret(secret).id for secret in add_secrets
1396+
]
1397+
1398+
if remove_secrets:
1399+
update_model.remove_secrets = [
1400+
self.get_secret(secret).id for secret in remove_secrets
1401+
]
1402+
13941403
updated_stack = self.zen_store.update_stack(
13951404
stack_id=stack.id,
13961405
stack_update=update_model,
@@ -2000,6 +2009,7 @@ def create_stack_component(
20002009
component_type: StackComponentType,
20012010
configuration: Dict[str, str],
20022011
labels: Optional[Dict[str, Any]] = None,
2012+
secrets: Optional[List[Union[UUID, str]]] = None,
20032013
) -> "ComponentResponse":
20042014
"""Registers a stack component.
20052015
@@ -2009,7 +2019,7 @@ def create_stack_component(
20092019
component_type: The type of the stack component.
20102020
configuration: The configuration of the stack component.
20112021
labels: The labels of the stack component.
2012-
2022+
secrets: The secrets of the stack component.
20132023
Returns:
20142024
The model of the registered component.
20152025
"""
@@ -2030,12 +2040,15 @@ def create_stack_component(
20302040
assert validated_config is not None
20312041
warn_if_config_server_mismatch(validated_config)
20322042

2043+
secret_ids = [self.get_secret(secret).id for secret in secrets or []]
2044+
20332045
create_component_model = ComponentRequest(
20342046
name=name,
20352047
type=component_type,
20362048
flavor=flavor,
20372049
configuration=configuration,
20382050
labels=labels,
2051+
secrets=secret_ids,
20392052
)
20402053

20412054
# Register the new model
@@ -2053,6 +2066,8 @@ def update_stack_component(
20532066
disconnect: Optional[bool] = None,
20542067
connector_id: Optional[UUID] = None,
20552068
connector_resource_id: Optional[str] = None,
2069+
add_secrets: Optional[List[Union[UUID, str]]] = None,
2070+
remove_secrets: Optional[List[Union[UUID, str]]] = None,
20562071
) -> ComponentResponse:
20572072
"""Updates a stack component.
20582073
@@ -2068,6 +2083,8 @@ def update_stack_component(
20682083
connector_id: The new connector id of the stack component.
20692084
connector_resource_id: The new connector resource id of the
20702085
stack component.
2086+
add_secrets: The secrets to add to the stack component.
2087+
remove_secrets: The secrets to remove from the stack component.
20712088
20722089
Returns:
20732090
The updated stack component.
@@ -2150,6 +2167,16 @@ def update_stack_component(
21502167
existing_component.connector_resource_id
21512168
)
21522169

2170+
if add_secrets:
2171+
update_model.add_secrets = [
2172+
self.get_secret(secret).id for secret in add_secrets
2173+
]
2174+
2175+
if remove_secrets:
2176+
update_model.remove_secrets = [
2177+
self.get_secret(secret).id for secret in remove_secrets
2178+
]
2179+
21532180
# Send the updated component to the ZenStore
21542181
return self.zen_store.update_stack_component(
21552182
component_id=component.id,

src/zenml/config/compiler.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,11 @@
3737
StepConfigurationUpdate,
3838
StepSpec,
3939
)
40-
from zenml.constants import (
41-
ENV_ZENML_ACTIVE_PROJECT_ID,
42-
ENV_ZENML_ACTIVE_STACK_ID,
43-
ENV_ZENML_STORE_PREFIX,
44-
)
4540
from zenml.environment import get_run_environment_dict
4641
from zenml.exceptions import StackValidationError
4742
from zenml.models import PipelineDeploymentBase
4843
from zenml.pipelines.run_utils import get_default_run_name
49-
from zenml.utils import pydantic_utils, settings_utils
44+
from zenml.utils import pydantic_utils, secret_utils, settings_utils
5045

5146
if TYPE_CHECKING:
5247
from zenml.config.source import Source
@@ -115,7 +110,9 @@ def compile(
115110
pipeline_environment = finalize_environment_variables(
116111
pipeline.configuration.environment
117112
)
118-
pipeline_secrets = pipeline.configuration.secrets.copy()
113+
pipeline_secrets = secret_utils.resolve_and_verify_secrets(
114+
pipeline.configuration.secrets
115+
)
119116
pipeline_settings = self._filter_and_validate_settings(
120117
settings=pipeline.configuration.settings,
121118
configuration_level=ConfigurationLevel.PIPELINE,
@@ -509,10 +506,10 @@ def _compile_step_invocation(
509506
step.configuration.settings, stack=stack
510507
)
511508
step_spec = self._get_step_spec(invocation=invocation)
512-
step_environment = finalize_environment_variables(
513-
step.configuration.environment
509+
step_environment = step.configuration.environment
510+
step_secrets = secret_utils.resolve_and_verify_secrets(
511+
step.configuration.secrets
514512
)
515-
step_secrets = step.configuration.secrets
516513
step_settings = self._filter_and_validate_settings(
517514
settings=step.configuration.settings,
518515
configuration_level=ConfigurationLevel.STEP,
@@ -722,17 +719,4 @@ def finalize_environment_variables(
722719
else:
723720
environment[key_without_prefix] = value
724721

725-
finalized_env = {}
726-
727-
for key, value in environment.items():
728-
if key.upper().startswith(ENV_ZENML_STORE_PREFIX) or key.upper() in [
729-
ENV_ZENML_ACTIVE_PROJECT_ID,
730-
ENV_ZENML_ACTIVE_STACK_ID,
731-
]:
732-
logger.warning(
733-
"Not allowed to set `%s` config environment variable.", key
734-
)
735-
continue
736-
finalized_env[key] = str(value)
737-
738-
return finalized_env
722+
return environment

src/zenml/config/pipeline_configurations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from datetime import datetime
1717
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
18+
from uuid import UUID
1819

1920
from pydantic import SerializeAsAny, field_validator
2021

@@ -42,7 +43,7 @@ class PipelineConfigurationUpdate(StrictBaseModel):
4243
enable_artifact_visualization: Optional[bool] = None
4344
enable_step_logs: Optional[bool] = None
4445
environment: Dict[str, Any] = {}
45-
secrets: List[str] = []
46+
secrets: List[Union[str, UUID]] = []
4647
enable_pipeline_logs: Optional[bool] = None
4748
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
4849
tags: Optional[List[Union[str, "Tag"]]] = None

src/zenml/config/pipeline_run_configuration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class PipelineRunConfiguration(
4747
)
4848
steps: Dict[str, StepConfigurationUpdate] = {}
4949
environment: Dict[str, Any] = {}
50-
secrets: List[str] = []
50+
secrets: List[Union[str, UUID]] = []
5151
settings: Dict[str, SerializeAsAny[BaseSettings]] = {}
5252
tags: Optional[List[Union[str, Tag]]] = None
5353
extra: Dict[str, Any] = {}

src/zenml/orchestrators/step_runner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,14 @@ def run(
189189
)
190190

191191
step_failed = False
192-
environment = env_utils.gather_step_environment(
192+
environment = env_utils.get_step_environment(
193193
step_config=step_run.config, stack=self._stack
194194
)
195+
secret_environment = env_utils.get_step_secret_environment(
196+
step_config=step_run.config, stack=self._stack
197+
)
198+
environment.update(secret_environment)
199+
195200
with env_utils.temporary_environment(environment):
196201
try:
197202
return_values = step_instance.call_entrypoint(

src/zenml/steps/base_step.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
materializer_utils,
6161
notebook_utils,
6262
pydantic_utils,
63-
secret_utils,
6463
settings_utils,
6564
source_code_utils,
6665
source_utils,
@@ -716,9 +715,6 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]:
716715
if merge and secrets and self._configuration.secrets:
717716
secrets = self._configuration.secrets + secrets
718717

719-
if secrets:
720-
secrets = secret_utils.convert_to_secret_ids(secrets)
721-
722718
values = dict_utils.remove_none_values(
723719
{
724720
"enable_cache": enable_cache,

src/zenml/utils/env_utils.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,10 @@ def temporary_environment(environment: Dict[str, str]) -> Iterator[None]:
203203
os.environ[key] = previous_value
204204

205205

206-
def gather_step_environment(
206+
def get_step_environment(
207207
step_config: "StepConfiguration", stack: "Stack"
208208
) -> Dict[str, str]:
209-
"""Gather the environment variables for a step.
209+
"""Get the environment variables for a step.
210210
211211
Args:
212212
step_config: The step configuration.
@@ -216,28 +216,46 @@ def gather_step_environment(
216216
A dictionary of environment variables.
217217
"""
218218
environment = {}
219-
secrets = []
220219
for component in stack.components.values():
221220
environment.update(component.environment)
222-
secrets.extend(component.secrets)
223221

224222
environment.update(stack.environment)
223+
environment.update(step_config.environment)
224+
225+
return environment
226+
227+
228+
def get_step_secret_environment(
229+
step_config: "StepConfiguration", stack: "Stack"
230+
) -> Dict[str, str]:
231+
"""Get the environment variables for a step.
232+
233+
Args:
234+
step_config: The step configuration.
235+
stack: The stack on which the step will run.
236+
237+
Returns:
238+
A dictionary of environment variables.
239+
"""
240+
secrets = step_config.secrets
225241
secrets.extend(stack.secrets)
226242

227-
environment.update(step_config.environment)
228-
secrets.extend(step_config.secrets)
243+
for component in stack.components.values():
244+
secrets.extend(component.secrets)
229245

230-
# Remove duplicates while preserving order, only the last occurrence of
231-
# each secret will be used to handle overrides
232-
secrets = list(reversed(dict.fromkeys(reversed(secrets))))
246+
# Removes duplicates while preserving order, only the first occurrence of
247+
# each secret will be kept. We then reverse the list to make sure the
248+
# overrides are applied in the correct order.
249+
secrets = list(reversed(dict.fromkeys(secrets)))
233250

251+
environment = {}
234252
for secret_name_or_id in secrets:
235253
try:
236254
secret = Client().get_secret(secret_name_or_id)
237255
except Exception as e:
238256
logger.warning(
239257
"Failed to get secret `%s` with error: %s. Skipping setting "
240-
"environment variable for this secret.",
258+
"environment variables for this secret.",
241259
secret_name_or_id,
242260
e,
243261
)
@@ -247,7 +265,7 @@ def gather_step_environment(
247265
logger.warning(
248266
"Did not find any secret values for secret `%s`. This might be "
249267
"because you do not have permissions to read the secret "
250-
"values. Skipping setting environment variable for this "
268+
"values. Skipping setting environment variables for this "
251269
"secret.",
252270
secret_name_or_id,
253271
)

src/zenml/utils/secret_utils.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,21 @@
1414
"""Utility functions for secrets and secret references."""
1515

1616
import re
17-
from typing import TYPE_CHECKING, Any, List, NamedTuple, Union
17+
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Union
1818

1919
from pydantic import Field, PlainSerializer, SecretStr
2020
from typing_extensions import Annotated
2121

2222
from zenml.logger import get_logger
23+
from zenml.utils import uuid_utils
2324

2425
if TYPE_CHECKING:
2526
from uuid import UUID
2627

2728
from pydantic.fields import FieldInfo
2829

30+
from zenml.zen_stores.zen_store_interface import ZenStoreInterface
31+
2932
_secret_reference_expression = re.compile(r"\{\{\s*\S+?\.\S+\s*\}\}")
3033

3134
PYDANTIC_SENSITIVE_FIELD_MARKER = "sensitive"
@@ -186,22 +189,47 @@ def is_clear_text_field(field: "FieldInfo") -> bool:
186189
return False
187190

188191

189-
def convert_to_secret_ids(secrets: List[Union[str, "UUID"]]) -> List["UUID"]:
192+
def resolve_and_verify_secrets(
193+
secrets: List[Union[str, "UUID"]],
194+
zen_store: Optional["ZenStoreInterface"] = None,
195+
) -> List["UUID"]:
190196
"""Convert a list of secret names or IDs to a list of secret IDs.
191197
192198
Args:
193199
secrets: A list of secret names or IDs.
200+
zen_store: The ZenML store to use to resolve the secrets.
194201
195202
Returns:
196203
A list of secret IDs.
197204
"""
198-
from zenml.client import Client
205+
if zen_store:
206+
from zenml.models import SecretFilter
199207

200-
secret_ids = []
201-
for secret in secrets:
202-
if isinstance(secret, str):
203-
secret_ids.append(Client().get_secret(secret, hydrate=False).id)
204-
else:
205-
secret_ids.append(secret)
208+
resolved_secrets = []
206209

207-
return secret_ids
210+
for secret_name_or_id in secrets:
211+
if uuid_utils.is_valid_uuid(secret_name_or_id):
212+
secret_id = (
213+
secret_name_or_id
214+
if isinstance(secret_name_or_id, UUID)
215+
else UUID(secret_name_or_id)
216+
)
217+
secret = zen_store.get_secret(secret_id, hydrate=False)
218+
resolved_secrets.append(secret.id)
219+
else:
220+
filter_model = SecretFilter(name=secret_name_or_id)
221+
secrets = zen_store.list_secrets(filter_model=filter_model)
222+
if not secrets:
223+
raise KeyError(
224+
f"Secret with name {secret_name_or_id} not found."
225+
)
226+
227+
resolved_secrets.append(secrets[0].id)
228+
229+
return resolved_secrets
230+
else:
231+
from zenml.client import Client
232+
233+
return [
234+
Client().get_secret(secret, hydrate=False).id for secret in secrets
235+
]

0 commit comments

Comments
 (0)