Skip to content

Commit 8013526

Browse files
authored
Merge branch 'master' into fix/kms_key_does_not_propagate_in_register_model
2 parents 79b864a + a05b10b commit 8013526

File tree

4 files changed

+85
-16
lines changed

4 files changed

+85
-16
lines changed

src/sagemaker/workflow/callback_step.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ def to_request(self) -> RequestType:
5858
"OutputType": self.output_type.value,
5959
}
6060

61-
@property
62-
def expr(self) -> Dict[str, str]:
63-
"""The 'Get' expression dict for a `Parameter`."""
64-
return CallbackOutput._expr(self.output_name)
61+
def expr(self, step_name) -> Dict[str, str]:
62+
"""The 'Get' expression dict for a `CallbackOutput`."""
63+
return CallbackOutput._expr(self.output_name, step_name)
6564

6665
@classmethod
67-
def _expr(cls, name):
66+
def _expr(cls, name, step_name):
6867
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.
6968
7069
Args:
7170
name (str): The name of the callback output.
71+
step_name (str): The name of the step the callback step associated
72+
with this output belongs to.
7273
"""
73-
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
74+
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7475

7576

7677
class CallbackStep(Step):

src/sagemaker/workflow/pipeline.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27-
from sagemaker.workflow.callback_step import CallbackOutput
27+
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
2828
from sagemaker.workflow.entities import (
2929
Entity,
3030
Expression,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240240
"""Converts a request structure to string representation for workflow service calls."""
241241
request_dict = self.to_request()
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243-
request_dict["PipelineExperimentConfig"]
243+
request_dict["PipelineExperimentConfig"], {}
244+
)
245+
callback_output_to_step_map = _map_callback_outputs(self.steps)
246+
request_dict["Steps"] = interpolate(
247+
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
244248
)
245-
request_dict["Steps"] = interpolate(request_dict["Steps"])
246249

247250
return json.dumps(request_dict)
248251

@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
264267

265268

266-
def interpolate(request_obj: RequestType) -> RequestType:
269+
def interpolate(
270+
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
271+
) -> RequestType:
267272
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268273
269274
Args:
270275
request_obj (RequestType): The request dict.
276+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271277
272278
Returns:
273279
RequestType: The request dict with Parameter values replaced by their expression.
274280
"""
275281
request_obj_copy = deepcopy(request_obj)
276-
return _interpolate(request_obj_copy)
282+
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)
277283

278284

279-
def _interpolate(obj: Union[RequestType, Any]):
285+
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
280286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281287
282288
Args:
283289
obj (Union[RequestType, Any]): The request dict.
290+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284291
"""
285-
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
292+
if isinstance(obj, (Expression, Parameter, Properties)):
286293
return obj.expr
294+
if isinstance(obj, CallbackOutput):
295+
step_name = callback_output_to_step_map[obj.output_name]
296+
return obj.expr(step_name)
287297
if isinstance(obj, dict):
288298
new = obj.__class__()
289299
for key, value in obj.items():
290-
new[key] = interpolate(value)
300+
new[key] = interpolate(value, callback_output_to_step_map)
291301
elif isinstance(obj, (list, set, tuple)):
292-
new = obj.__class__(interpolate(value) for value in obj)
302+
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
293303
else:
294304
return obj
295305
return new
296306

297307

308+
def _map_callback_outputs(steps: List[Step]):
309+
"""Iterate over the provided steps, building a map of callback output parameters to step names.
310+
311+
Args:
312+
step (List[Step]): The steps list.
313+
"""
314+
315+
callback_output_map = {}
316+
for step in steps:
317+
if isinstance(step, CallbackStep):
318+
if step.outputs:
319+
for output in step.outputs:
320+
callback_output_map[output.output_name] = step.name
321+
322+
return callback_output_map
323+
324+
298325
def update_args(args: Dict[str, Any], **kwargs):
299326
"""Updates the request arguments dict with a value, if populated.
300327

tests/integ/test_workflow.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,47 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
740740
pass
741741

742742

743+
def test_two_step_callback_pipeline_with_output_reference(
744+
sagemaker_session, role, pipeline_name, region_name
745+
):
746+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
747+
748+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
749+
step_callback1 = CallbackStep(
750+
name="callback-step1",
751+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
752+
inputs={"arg1": "foo"},
753+
outputs=[outputParam1],
754+
)
755+
756+
step_callback2 = CallbackStep(
757+
name="callback-step2",
758+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
759+
inputs={"arg1": outputParam1},
760+
outputs=[],
761+
)
762+
763+
pipeline = Pipeline(
764+
name=pipeline_name,
765+
parameters=[instance_count],
766+
steps=[step_callback1, step_callback2],
767+
sagemaker_session=sagemaker_session,
768+
)
769+
770+
try:
771+
response = pipeline.create(role)
772+
create_arn = response["PipelineArn"]
773+
assert re.match(
774+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
775+
create_arn,
776+
)
777+
finally:
778+
try:
779+
pipeline.delete()
780+
except Exception:
781+
pass
782+
783+
743784
def test_conditional_pytorch_training_model_registration(
744785
sagemaker_session,
745786
role,

tests/unit/sagemaker/workflow/test_callback_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_pipeline_interpolates_callback_outputs():
8888
name="MyCallbackStep2",
8989
depends_on=["TestStep"],
9090
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
91-
inputs={"arg1": cb_step1.properties.Outputs["output1"]},
91+
inputs={"arg1": outputParam1},
9292
outputs=[outputParam2],
9393
)
9494

0 commit comments

Comments
 (0)