
Commit 056f71a

fix: Move func and args serialization of function step to step level
1 parent: 2432b26

10 files changed: +328 −85 lines changed

src/sagemaker/remote_function/core/pipeline_variables.py

+37 −4

@@ -77,6 +77,17 @@ class _ExecutionVariable:
     name: str
 
 
+@dataclass
+class _S3BaseUriIdentifier:
+    """Identifies that the class refers to function step s3 base uri.
+
+    The s3_base_uri = s3_root_uri + pipeline_name.
+    This identifier is resolved in function step runtime by SDK.
+    """
+
+    NAME = "S3_BASE_URI"
+
+
 @dataclass
 class _DelayedReturn:
     """Delayed return from a function."""
@@ -155,6 +166,7 @@ def __init__(
         hmac_key: str,
         parameter_resolver: _ParameterResolver,
         execution_variable_resolver: _ExecutionVariableResolver,
+        s3_base_uri: str,
         **settings,
     ):
         """Resolve delayed return.
@@ -164,8 +176,12 @@ def __init__(
             hmac_key: key used to encrypt serialized and deserialized function and arguments.
             parameter_resolver: resolver used to pipeline parameters.
             execution_variable_resolver: resolver used to resolve execution variables.
+            s3_base_uri (str): the s3 base uri of the function step that
+                the serialized artifacts will be uploaded to.
+                The s3_base_uri = s3_root_uri + pipeline_name.
             **settings: settings to pass to the deserialization function.
         """
+        self._s3_base_uri = s3_base_uri
         self._parameter_resolver = parameter_resolver
         self._execution_variable_resolver = execution_variable_resolver
         # different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
                 uri.append(self._parameter_resolver.resolve(component))
             elif isinstance(component, _ExecutionVariable):
                 uri.append(self._execution_variable_resolver.resolve(component))
+            elif isinstance(component, _S3BaseUriIdentifier):
+                uri.append(self._s3_base_uri)
             else:
                 uri.append(component)
         return s3_path_join(*uri)
@@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):
 
 
 def resolve_pipeline_variables(
-    context: Context, func_args: Tuple, func_kwargs: Dict, hmac_key: str, **settings
+    context: Context,
+    func_args: Tuple,
+    func_kwargs: Dict,
+    hmac_key: str,
+    s3_base_uri: str,
+    **settings,
 ):
     """Resolve pipeline variables.
@@ -228,6 +251,8 @@ def resolve_pipeline_variables(
         func_args: function args.
         func_kwargs: function kwargs.
         hmac_key: key used to encrypt serialized and deserialized function and arguments.
+        s3_base_uri: the s3 base uri of the function step that the serialized artifacts
+            will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
         **settings: settings to pass to the deserialization function.
     """
 
@@ -251,6 +276,7 @@ def resolve_pipeline_variables(
         hmac_key=hmac_key,
         parameter_resolver=parameter_resolver,
         execution_variable_resolver=execution_variable_resolver,
+        s3_base_uri=s3_base_uri,
         **settings,
     )
 
@@ -289,11 +315,10 @@
     return resolved_func_args, resolved_func_kwargs
 
 
-def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, func_kwargs: Dict):
+def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict):
     """Convert pipeline variables to pickleable.
 
     Args:
-        s3_base_uri: s3 base uri where artifacts are stored.
         func_args: function args.
         func_kwargs: function kwargs.
     """
@@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,
 
     from sagemaker.workflow.function_step import DelayedReturn
 
+    # Notes:
+    #   1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
+    #      when defining function steps. After step-level arg serialization,
+    #      it's hard to update the s3_base_uri in pipeline compile time.
+    #      Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
+    #   2. For saying s3_root_uri is unknown, it's because when defining function steps,
+    #      the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
+    #      should be retrieved from the pipeline's sagemaker_session.
     def convert(arg):
         if isinstance(arg, DelayedReturn):
             return _DelayedReturn(
                 uri=[
-                    s3_base_uri,
+                    _S3BaseUriIdentifier(),
                     ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
                     arg._step.name,
                     "results",

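The Notes block above is the heart of this change: because s3_root_uri and pipeline_name may be unknown when a step is defined, the serialized arguments carry a _S3BaseUriIdentifier placeholder that the runtime job later swaps for the real base uri. A minimal, self-contained sketch of that substitution; resolve_uri, the simplified s3_path_join, and the literal values are illustrative stand-ins rather than SDK API:

from dataclasses import dataclass


@dataclass
class _S3BaseUriIdentifier:
    """Placeholder for s3_root_uri + pipeline_name, resolved at job runtime."""

    NAME = "S3_BASE_URI"


def s3_path_join(*parts):
    # Simplified stand-in for sagemaker.s3.s3_path_join.
    return "/".join(str(p).rstrip("/") for p in parts)


def resolve_uri(components, s3_base_uri):
    # Mirrors the new elif branch in _resolve_delayed_return_uri: every
    # placeholder is replaced with the base uri known only at runtime.
    resolved = [
        s3_base_uri if isinstance(c, _S3BaseUriIdentifier) else c for c in components
    ]
    return s3_path_join(*resolved)


print(
    resolve_uri(
        [_S3BaseUriIdentifier(), "execution-id", "step-name", "results"],
        s3_base_uri="s3://bucket/my-pipeline",
    )
)  # -> s3://bucket/my-pipeline/execution-id/step-name/results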
src/sagemaker/remote_function/core/serialization.py

+37 −22

@@ -161,17 +161,13 @@ def serialize_func_to_s3(
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    bytes_to_upload = CloudpickleSerializer.serialize(func)
 
-    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
-
-    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
-
-    _upload_bytes_to_s3(
-        _MetaData(sha256_hash).to_json(),
-        f"{s3_uri}/metadata.json",
-        s3_kms_key,
-        sagemaker_session,
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(func),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
     )
 
 
@@ -220,17 +216,12 @@ def serialize_obj_to_s3(
         SerializationError: when fail to serialize object to bytes.
     """
 
-    bytes_to_upload = CloudpickleSerializer.serialize(obj)
-
-    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
-
-    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
-
-    _upload_bytes_to_s3(
-        _MetaData(sha256_hash).to_json(),
-        f"{s3_uri}/metadata.json",
-        s3_kms_key,
-        sagemaker_session,
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(obj),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
     )
 
 
@@ -318,8 +309,32 @@ def serialize_exception_to_s3(
     """
     pickling_support.install()
 
-    bytes_to_upload = CloudpickleSerializer.serialize(exc)
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(exc),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
+    )
 
+
+def _upload_payload_and_metadata_to_s3(
+    bytes_to_upload: Union[bytes, io.BytesIO],
+    hmac_key: str,
+    s3_uri: str,
+    sagemaker_session: Session,
+    s3_kms_key,
+):
+    """Uploads serialized payload and metadata to s3.
+
+    Args:
+        bytes_to_upload (bytes): Serialized bytes to upload.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
+        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        sagemaker_session (sagemaker.session.Session):
+            The underlying Boto3 session which AWS service calls are delegated to.
+        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+    """
     _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
 
     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
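For context, the convention that _upload_payload_and_metadata_to_s3 consolidates is a pair of objects: payload.pkl holding the serialized bytes and metadata.json holding their HMAC-SHA256 hash. A local-filesystem sketch of that pairing, assuming the metadata is a plain {"sha256_hash": ...} JSON object (the diff does not show _MetaData's exact format):

import hashlib
import hmac
import json
import pathlib
import pickle


def upload_payload_and_metadata(payload: bytes, hmac_key: str, base_dir: str) -> None:
    # payload.pkl: the serialized object; metadata.json: its keyed hash, so the
    # loader can verify integrity before unpickling.
    digest = hmac.new(hmac_key.encode(), payload, hashlib.sha256).hexdigest()
    out = pathlib.Path(base_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "payload.pkl").write_bytes(payload)
    (out / "metadata.json").write_text(json.dumps({"sha256_hash": digest}))


upload_payload_and_metadata(
    payload=pickle.dumps({"answer": 42}),  # the SDK uses cloudpickle here
    hmac_key="secret",
    base_dir="/tmp/payload-demo",
)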

src/sagemaker/remote_function/core/stored_function.py

+42 −0

@@ -14,6 +14,7 @@
 from __future__ import absolute_import
 
 import os
+from dataclasses import dataclass
 from typing import Any
 
 
@@ -36,6 +37,14 @@
 JSON_RESULTS_FILE = "results.json"
 
 
+@dataclass
+class _SerializedData:
+    """Data class to store serialized function and arguments"""
+
+    func: bytes
+    args: bytes
+
+
 class StoredFunction:
     """Class representing a remote function stored in S3."""
 
@@ -105,6 +114,38 @@ def save(self, func, *args, **kwargs):
             s3_kms_key=self.s3_kms_key,
         )
 
+    def save_pipeline_step_function(self, serialized_data):
+        """Upload serialized function and arguments to s3.
+
+        Args:
+            serialized_data (_SerializedData): The serialized function
+                and function arguments of a function step.
+        """
+
+        logger.info(
+            "Uploading serialized function code to %s",
+            s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
+        )
+        serialization._upload_payload_and_metadata_to_s3(
+            bytes_to_upload=serialized_data.func,
+            hmac_key=self.hmac_key,
+            s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
+            sagemaker_session=self.sagemaker_session,
+            s3_kms_key=self.s3_kms_key,
+        )
+
+        logger.info(
+            "Uploading serialized function arguments to %s",
+            s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
+        )
+        serialization._upload_payload_and_metadata_to_s3(
+            bytes_to_upload=serialized_data.args,
+            hmac_key=self.hmac_key,
+            s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
+            sagemaker_session=self.sagemaker_session,
+            s3_kms_key=self.s3_kms_key,
+        )
+
     def load_and_invoke(self) -> Any:
         """Load and deserialize the function and the arguments and then execute it."""
 
@@ -134,6 +175,7 @@ def load_and_invoke(self) -> Any:
             args,
             kwargs,
             hmac_key=self.hmac_key,
+            s3_base_uri=self.s3_base_uri,
             sagemaker_session=self.sagemaker_session,
         )
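save_pipeline_step_function is two of those payload-plus-metadata uploads, one per artifact. A hedged local sketch of the resulting layout under func_upload_path, with "function" and "arguments" as assumed stand-ins for FUNCTION_FOLDER and ARGUMENTS_FOLDER (their values are not shown in this diff), and metadata.json omitted for brevity:

import pathlib
import pickle


def save_pipeline_step_function(func_bytes: bytes, args_bytes: bytes, upload_path: str) -> None:
    # One folder per artifact; each would also get a metadata.json in the SDK.
    for folder, payload in (("function", func_bytes), ("arguments", args_bytes)):
        target = pathlib.Path(upload_path) / folder
        target.mkdir(parents=True, exist_ok=True)
        (target / "payload.pkl").write_bytes(payload)


save_pipeline_step_function(
    func_bytes=pickle.dumps(len),             # stands in for a cloudpickled function
    args_bytes=pickle.dumps((("abc",), {})),  # the (converted args, kwargs) tuple
    upload_path="/tmp/step-artifacts",
)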

src/sagemaker/remote_function/job.py

+5 −11

@@ -52,11 +52,8 @@
 from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
 from sagemaker.s3 import s3_path_join, S3Uploader
 from sagemaker import vpc_utils
-from sagemaker.remote_function.core.stored_function import StoredFunction
-from sagemaker.remote_function.core.pipeline_variables import (
-    Context,
-    convert_pipeline_variables_to_pickleable,
-)
+from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
+from sagemaker.remote_function.core.pipeline_variables import Context
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
     RuntimeEnvironmentManager,
     _DependencySettings,
@@ -695,6 +692,7 @@ def compile(
     func_args: tuple,
     func_kwargs: dict,
     run_info=None,
+    serialized_data: _SerializedData = None,
 ) -> dict:
     """Build the artifacts and generate the training job request."""
     from sagemaker.workflow.properties import Properties
@@ -732,12 +730,8 @@
             func_step_s3_dir=step_compilation_context.pipeline_build_time,
         ),
     )
-    converted_func_args, converted_func_kwargs = convert_pipeline_variables_to_pickleable(
-        s3_base_uri=s3_base_uri,
-        func_args=func_args,
-        func_kwargs=func_kwargs,
-    )
-    stored_function.save(func, *converted_func_args, **converted_func_kwargs)
+
+    stored_function.save_pipeline_step_function(serialized_data)
 
     stopping_condition = {
         "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,

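With this change, compile() no longer serializes the function or its arguments; it only uploads the bytes the step produced when it was defined. A side effect worth illustrating (my illustration, not something stated in the diff): serializing at definition time pins captured values at that moment. A minimal sketch with stdlib pickle and functools.partial standing in for cloudpickle:

import pickle
from functools import partial


def over(x, limit):
    return x > limit


threshold = 0.5
step_payload = pickle.dumps(partial(over, limit=threshold))  # pinned now

threshold = 0.9  # later mutation no longer affects the serialized step
restored = pickle.loads(step_payload)
print(restored(0.7))  # True: uses the limit captured at definition time (0.5)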
src/sagemaker/workflow/function_step.py

+21 −0

@@ -83,6 +83,11 @@ def __init__(
             func_kwargs (dict): keyword arguments of the python function.
             **kwargs: Additional arguments to be passed to the `step` decorator.
         """
+        from sagemaker.remote_function.core.pipeline_variables import (
+            convert_pipeline_variables_to_pickleable,
+        )
+        from sagemaker.remote_function.core.serialization import CloudpickleSerializer
+        from sagemaker.remote_function.core.stored_function import _SerializedData
 
         super(_FunctionStep, self).__init__(
             name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies
@@ -96,6 +101,21 @@ def __init__(
 
         self.__job_settings = None
 
+        (
+            self._converted_func_args,
+            self._converted_func_kwargs,
+        ) = convert_pipeline_variables_to_pickleable(
+            func_args=self._func_args,
+            func_kwargs=self._func_kwargs,
+        )
+
+        self._serialized_data = _SerializedData(
+            func=CloudpickleSerializer.serialize(self._func),
+            args=CloudpickleSerializer.serialize(
+                (self._converted_func_args, self._converted_func_kwargs)
+            ),
+        )
+
     @property
     def func(self):
         """The python function to run as a pipeline step."""
@@ -185,6 +205,7 @@ def arguments(self) -> RequestType:
             func=self.func,
             func_args=self.func_args,
             func_kwargs=self.func_kwargs,
+            serialized_data=self._serialized_data,
         )
         # Continue to pop job name if not explicitly opted-in via config
         request_dict = trim_request_dict(request_dict, "TrainingJobName", step_compilation_context)
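Putting the pieces together, _FunctionStep now serializes eagerly in __init__, caches the bytes in _SerializedData, and hands them to compile() via arguments. A condensed, self-contained sketch of that pattern, with pickle standing in for CloudpickleSerializer and the real constructor's extra parameters omitted:

import pickle
from dataclasses import dataclass


@dataclass
class _SerializedData:
    func: bytes
    args: bytes


class FunctionStep:
    def __init__(self, func, func_args, func_kwargs):
        # Serialize once, at step definition time; pipeline variables would be
        # converted to pickleable placeholders before this point in the SDK.
        self._serialized_data = _SerializedData(
            func=pickle.dumps(func),
            args=pickle.dumps((func_args, func_kwargs)),
        )

    @property
    def arguments(self):
        # compile() receives the pre-serialized blobs and only uploads them.
        return {"serialized_data": self._serialized_data}


step = FunctionStep(max, func_args=(1, 2), func_kwargs={})
print(len(step.arguments["serialized_data"].func))  # size of the pickled func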
