24
24
25
25
from sagemaker ._studio import _append_project_tags
26
26
from sagemaker .session import Session
27
- from sagemaker .workflow .callback_step import CallbackOutput
27
+ from sagemaker .workflow .callback_step import CallbackOutput , CallbackStep
28
28
from sagemaker .workflow .entities import (
29
29
Entity ,
30
30
Expression ,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240
240
"""Converts a request structure to string representation for workflow service calls."""
241
241
request_dict = self .to_request ()
242
242
request_dict ["PipelineExperimentConfig" ] = interpolate (
243
- request_dict ["PipelineExperimentConfig" ]
243
+ request_dict ["PipelineExperimentConfig" ], {}
244
+ )
245
+ callback_output_to_step_map = _map_callback_outputs (self .steps )
246
+ request_dict ["Steps" ] = interpolate (
247
+ request_dict ["Steps" ], callback_output_to_step_map = callback_output_to_step_map
244
248
)
245
- request_dict ["Steps" ] = interpolate (request_dict ["Steps" ])
246
249
247
250
return json .dumps (request_dict )
248
251
@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263
266
return [{"Name" : name , "Value" : str (value )} for name , value in parameters .items ()]
264
267
265
268
266
- def interpolate (request_obj : RequestType ) -> RequestType :
269
+ def interpolate (
270
+ request_obj : RequestType , callback_output_to_step_map : Dict [str , str ]
271
+ ) -> RequestType :
267
272
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268
273
269
274
Args:
270
275
request_obj (RequestType): The request dict.
276
+ callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271
277
272
278
Returns:
273
279
RequestType: The request dict with Parameter values replaced by their expression.
274
280
"""
275
281
request_obj_copy = deepcopy (request_obj )
276
- return _interpolate (request_obj_copy )
282
+ return _interpolate (request_obj_copy , callback_output_to_step_map = callback_output_to_step_map )
277
283
278
284
279
- def _interpolate (obj : Union [RequestType , Any ]):
285
+ def _interpolate (obj : Union [RequestType , Any ], callback_output_to_step_map : Dict [ str , str ] ):
280
286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281
287
282
288
Args:
283
289
obj (Union[RequestType, Any]): The request dict.
290
+ callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284
291
"""
285
- if isinstance (obj , (Expression , Parameter , Properties , CallbackOutput )):
292
+ if isinstance (obj , (Expression , Parameter , Properties )):
286
293
return obj .expr
294
+ if isinstance (obj , CallbackOutput ):
295
+ step_name = callback_output_to_step_map [obj .output_name ]
296
+ return obj .expr (step_name )
287
297
if isinstance (obj , dict ):
288
298
new = obj .__class__ ()
289
299
for key , value in obj .items ():
290
- new [key ] = interpolate (value )
300
+ new [key ] = interpolate (value , callback_output_to_step_map )
291
301
elif isinstance (obj , (list , set , tuple )):
292
- new = obj .__class__ (interpolate (value ) for value in obj )
302
+ new = obj .__class__ (interpolate (value , callback_output_to_step_map ) for value in obj )
293
303
else :
294
304
return obj
295
305
return new
296
306
297
307
308
+ def _map_callback_outputs (steps : List [Step ]):
309
+ """Iterate over the provided steps, building a map of callback output parameters to step names.
310
+
311
+ Args:
312
+ step (List[Step]): The steps list.
313
+ """
314
+
315
+ callback_output_map = {}
316
+ for step in steps :
317
+ if isinstance (step , CallbackStep ):
318
+ if step .outputs :
319
+ for output in step .outputs :
320
+ callback_output_map [output .output_name ] = step .name
321
+
322
+ return callback_output_map
323
+
324
+
298
325
def update_args (args : Dict [str , Any ], ** kwargs ):
299
326
"""Updates the request arguments dict with a value, if populated.
300
327
0 commit comments