Skip to content

Commit 9e2f78e

Browse files
committed
Copy sagemaker env to respect step specific settings
1 parent c8accde commit 9e2f78e

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,13 @@ def prepare_or_run_pipeline(
309309
env=environment,
310310
)
311311

312+
environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
313+
ExecutionVariables.PIPELINE_EXECUTION_ARN
314+
)
315+
312316
sagemaker_steps = []
313317
for step_name, step in deployment.step_configurations.items():
318+
step_environment = environment.copy()
314319
image = self.get_image(deployment=deployment, step_name=step_name)
315320
command = SagemakerEntrypointConfiguration.get_entrypoint_command()
316321
arguments = (
@@ -324,22 +329,18 @@ def prepare_or_run_pipeline(
324329
SagemakerOrchestratorSettings, self.get_settings(step)
325330
)
326331

327-
environment[ENV_ZENML_SAGEMAKER_RUN_ID] = (
328-
ExecutionVariables.PIPELINE_EXECUTION_ARN
329-
)
330-
331332
if step_settings.environment:
332-
step_environment = step_settings.environment.copy()
333+
user_defined_environment = step_settings.environment.copy()
333334
# Sagemaker does not allow environment variables longer than 256
334335
# characters to be passed to Processor steps. If an environment variable
335336
# is longer than 256 characters, we split it into multiple environment
336337
# variables (chunks) and re-construct it on the other side using the
337338
# custom entrypoint configuration.
338339
split_environment_variables(
339340
size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
340-
env=step_environment,
341+
env=user_defined_environment,
341342
)
342-
environment.update(step_environment)
343+
step_environment.update(user_defined_environment)
343344

344345
use_training_step = (
345346
step_settings.use_training_step
@@ -476,19 +477,19 @@ def prepare_or_run_pipeline(
476477
)
477478

478479
# Convert environment to a dict of strings
479-
environment = {
480+
step_environment = {
480481
key: str(value)
481482
if not isinstance(value, ExecutionVariable)
482483
else value
483-
for key, value in environment.items()
484+
for key, value in step_environment.items()
484485
}
485486

486487
if use_training_step:
487488
# Create Estimator and TrainingStep
488489
estimator = sagemaker.estimator.Estimator(
489490
keep_alive_period_in_seconds=step_settings.keep_alive_period_in_seconds,
490491
output_path=output_path,
491-
environment=environment,
492+
environment=step_environment,
492493
container_entry_point=entrypoint,
493494
**args_for_step_executor,
494495
)
@@ -502,7 +503,7 @@ def prepare_or_run_pipeline(
502503
# Create Processor and ProcessingStep
503504
processor = sagemaker.processing.Processor(
504505
entrypoint=entrypoint,
505-
env=environment,
506+
env=step_environment,
506507
**args_for_step_executor,
507508
)
508509

0 commit comments

Comments
 (0)