@@ -77,6 +77,17 @@ class _ExecutionVariable:
77
77
name : str
78
78
79
79
80
+ @dataclass
81
+ class _S3BaseUriIdentifier :
82
+ """Identifies that the class refers to function step s3 base uri.
83
+
84
+ The s3_base_uri = s3_root_uri + pipeline_name.
85
+ This identifier is resolved in function step runtime by SDK.
86
+ """
87
+
88
+ NAME = "S3_BASE_URI"
89
+
90
+
80
91
@dataclass
81
92
class _DelayedReturn :
82
93
"""Delayed return from a function."""
@@ -155,6 +166,7 @@ def __init__(
155
166
hmac_key : str ,
156
167
parameter_resolver : _ParameterResolver ,
157
168
execution_variable_resolver : _ExecutionVariableResolver ,
169
+ s3_base_uri : str ,
158
170
** settings ,
159
171
):
160
172
"""Resolve delayed return.
@@ -164,8 +176,12 @@ def __init__(
164
176
hmac_key: key used to encrypt serialized and deserialized function and arguments.
165
177
parameter_resolver: resolver used to pipeline parameters.
166
178
execution_variable_resolver: resolver used to resolve execution variables.
179
+ s3_base_uri (str): the s3 base uri of the function step that
180
+ the serialized artifacts will be uploaded to.
181
+ The s3_base_uri = s3_root_uri + pipeline_name.
167
182
**settings: settings to pass to the deserialization function.
168
183
"""
184
+ self ._s3_base_uri = s3_base_uri
169
185
self ._parameter_resolver = parameter_resolver
170
186
self ._execution_variable_resolver = execution_variable_resolver
171
187
# different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
205
221
uri .append (self ._parameter_resolver .resolve (component ))
206
222
elif isinstance (component , _ExecutionVariable ):
207
223
uri .append (self ._execution_variable_resolver .resolve (component ))
224
+ elif isinstance (component , _S3BaseUriIdentifier ):
225
+ uri .append (self ._s3_base_uri )
208
226
else :
209
227
uri .append (component )
210
228
return s3_path_join (* uri )
@@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):
219
237
220
238
221
239
def resolve_pipeline_variables (
222
- context : Context , func_args : Tuple , func_kwargs : Dict , hmac_key : str , ** settings
240
+ context : Context ,
241
+ func_args : Tuple ,
242
+ func_kwargs : Dict ,
243
+ hmac_key : str ,
244
+ s3_base_uri : str ,
245
+ ** settings ,
223
246
):
224
247
"""Resolve pipeline variables.
225
248
@@ -228,6 +251,8 @@ def resolve_pipeline_variables(
228
251
func_args: function args.
229
252
func_kwargs: function kwargs.
230
253
hmac_key: key used to encrypt serialized and deserialized function and arguments.
254
+ s3_base_uri: the s3 base uri of the function step that the serialized artifacts
255
+ will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
231
256
**settings: settings to pass to the deserialization function.
232
257
"""
233
258
@@ -251,6 +276,7 @@ def resolve_pipeline_variables(
251
276
hmac_key = hmac_key ,
252
277
parameter_resolver = parameter_resolver ,
253
278
execution_variable_resolver = execution_variable_resolver ,
279
+ s3_base_uri = s3_base_uri ,
254
280
** settings ,
255
281
)
256
282
@@ -289,11 +315,10 @@ def resolve_pipeline_variables(
289
315
return resolved_func_args , resolved_func_kwargs
290
316
291
317
292
- def convert_pipeline_variables_to_pickleable (s3_base_uri : str , func_args : Tuple , func_kwargs : Dict ):
318
+ def convert_pipeline_variables_to_pickleable (func_args : Tuple , func_kwargs : Dict ):
293
319
"""Convert pipeline variables to pickleable.
294
320
295
321
Args:
296
- s3_base_uri: s3 base uri where artifacts are stored.
297
322
func_args: function args.
298
323
func_kwargs: function kwargs.
299
324
"""
@@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,
304
329
305
330
from sagemaker .workflow .function_step import DelayedReturn
306
331
332
+ # Notes:
333
+ # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334
+ # when defining function steps. After step-level arg serialization,
335
+ # it's hard to update the s3_base_uri in pipeline compile time.
336
+ # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337
+ # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338
+ # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339
+ # should be retrieved from the pipeline's sagemaker_session.
307
340
def convert (arg ):
308
341
if isinstance (arg , DelayedReturn ):
309
342
return _DelayedReturn (
310
343
uri = [
311
- s3_base_uri ,
344
+ _S3BaseUriIdentifier () ,
312
345
ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
313
346
arg ._step .name ,
314
347
"results" ,
0 commit comments