@@ -30,6 +30,7 @@ class Context:
30
30
property_references : Dict [str , str ] = field (default_factory = dict )
31
31
serialize_output_to_json : bool = False
32
32
func_step_s3_dir : str = None
33
+ s3_base_uri : str = None
33
34
34
35
35
36
@dataclass
@@ -77,6 +78,17 @@ class _ExecutionVariable:
77
78
name : str
78
79
79
80
81
+ @dataclass
82
+ class _S3BaseUriIdentifier :
83
+ """Identifies that the class refers to function step s3 base uri.
84
+
85
+ The s3_base_uri = s3_root_uri + pipeline_name.
86
+ This identifier is resolved in function step runtime by SDK.
87
+ """
88
+
89
+ NAME = "S3_BASE_URI"
90
+
91
+
80
92
@dataclass
81
93
class _DelayedReturn :
82
94
"""Delayed return from a function."""
@@ -155,6 +167,7 @@ def __init__(
155
167
hmac_key : str ,
156
168
parameter_resolver : _ParameterResolver ,
157
169
execution_variable_resolver : _ExecutionVariableResolver ,
170
+ s3_base_uri : str ,
158
171
** settings ,
159
172
):
160
173
"""Resolve delayed return.
@@ -164,8 +177,11 @@ def __init__(
164
177
hmac_key: key used to encrypt serialized and deserialized function and arguments.
165
178
parameter_resolver: resolver used to pipeline parameters.
166
179
execution_variable_resolver: resolver used to resolve execution variables.
180
+ s3_base_uri (str): the s3 base uri of the function step that
181
+ the DelayedReturn object associates with.
167
182
**settings: settings to pass to the deserialization function.
168
183
"""
184
+ self ._s3_base_uri = s3_base_uri
169
185
self ._parameter_resolver = parameter_resolver
170
186
self ._execution_variable_resolver = execution_variable_resolver
171
187
# different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
205
221
uri .append (self ._parameter_resolver .resolve (component ))
206
222
elif isinstance (component , _ExecutionVariable ):
207
223
uri .append (self ._execution_variable_resolver .resolve (component ))
224
+ elif isinstance (component , _S3BaseUriIdentifier ):
225
+ uri .append (self ._s3_base_uri )
208
226
else :
209
227
uri .append (component )
210
228
return s3_path_join (* uri )
@@ -251,6 +269,7 @@ def resolve_pipeline_variables(
251
269
hmac_key = hmac_key ,
252
270
parameter_resolver = parameter_resolver ,
253
271
execution_variable_resolver = execution_variable_resolver ,
272
+ s3_base_uri = context .s3_base_uri ,
254
273
** settings ,
255
274
)
256
275
@@ -289,11 +308,10 @@ def resolve_pipeline_variables(
289
308
return resolved_func_args , resolved_func_kwargs
290
309
291
310
292
- def convert_pipeline_variables_to_pickleable (s3_base_uri : str , func_args : Tuple , func_kwargs : Dict ):
311
+ def convert_pipeline_variables_to_pickleable (func_args : Tuple , func_kwargs : Dict ):
293
312
"""Convert pipeline variables to pickleable.
294
313
295
314
Args:
296
- s3_base_uri: s3 base uri where artifacts are stored.
297
315
func_args: function args.
298
316
func_kwargs: function kwargs.
299
317
"""
@@ -304,11 +322,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,
304
322
305
323
from sagemaker .workflow .function_step import DelayedReturn
306
324
325
+ # Notes:
326
+ # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
327
+ # when defining function steps and after function. After step-level arg serialization,
328
+ # it's hard to update the s3_base_uri in pipeline compile time.
329
+ # Thus set a placeholder _S3BaseUriIdentifier here and let the runtime job to resolve it.
330
+ # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
331
+ # the sagemaker_session is not passed in the pipeline but the default s3_root_uri
332
+ # should be retrieved from the pipeline's sagemaker_session.
307
333
def convert (arg ):
308
334
if isinstance (arg , DelayedReturn ):
309
335
return _DelayedReturn (
310
336
uri = [
311
- s3_base_uri ,
337
+ _S3BaseUriIdentifier () ,
312
338
ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
313
339
arg ._step .name ,
314
340
"results" ,
0 commit comments