
Commit e663de8

fix: Move func and args serialization of function step to step level

1 parent: 8c2012b

File tree: 11 files changed (+311 -81 lines)

src/sagemaker/remote_function/core/pipeline_variables.py (+29 -3)

```diff
@@ -30,6 +30,7 @@ class Context:
     property_references: Dict[str, str] = field(default_factory=dict)
     serialize_output_to_json: bool = False
     func_step_s3_dir: str = None
+    s3_base_uri: str = None


 @dataclass
@@ -77,6 +78,17 @@ class _ExecutionVariable:
     name: str


+@dataclass
+class _S3BaseUriIdentifier:
+    """Identifies that the class refers to function step s3 base uri.
+
+    The s3_base_uri = s3_root_uri + pipeline_name.
+    This identifier is resolved in function step runtime by SDK.
+    """
+
+    NAME = "S3_BASE_URI"
+
+
 @dataclass
 class _DelayedReturn:
     """Delayed return from a function."""
@@ -155,6 +167,7 @@ def __init__(
         hmac_key: str,
         parameter_resolver: _ParameterResolver,
         execution_variable_resolver: _ExecutionVariableResolver,
+        s3_base_uri: str,
         **settings,
     ):
         """Resolve delayed return.
@@ -164,8 +177,11 @@ def __init__(
             hmac_key: key used to encrypt serialized and deserialized function and arguments.
             parameter_resolver: resolver used to resolve pipeline parameters.
             execution_variable_resolver: resolver used to resolve execution variables.
+            s3_base_uri (str): the s3 base uri of the function step that
+                the DelayedReturn object associates with.
             **settings: settings to pass to the deserialization function.
         """
+        self._s3_base_uri = s3_base_uri
         self._parameter_resolver = parameter_resolver
         self._execution_variable_resolver = execution_variable_resolver
         # different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
                 uri.append(self._parameter_resolver.resolve(component))
             elif isinstance(component, _ExecutionVariable):
                 uri.append(self._execution_variable_resolver.resolve(component))
+            elif isinstance(component, _S3BaseUriIdentifier):
+                uri.append(self._s3_base_uri)
             else:
                 uri.append(component)
         return s3_path_join(*uri)
@@ -251,6 +269,7 @@ def resolve_pipeline_variables(
         hmac_key=hmac_key,
         parameter_resolver=parameter_resolver,
         execution_variable_resolver=execution_variable_resolver,
+        s3_base_uri=context.s3_base_uri,
         **settings,
     )
@@ -289,11 +308,10 @@ def resolve_pipeline_variables(
     return resolved_func_args, resolved_func_kwargs


-def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, func_kwargs: Dict):
+def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict):
    """Convert pipeline variables to pickleable.

     Args:
-        s3_base_uri: s3 base uri where artifacts are stored.
         func_args: function args.
         func_kwargs: function kwargs.
     """
@@ -304,11 +322,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,

     from sagemaker.workflow.function_step import DelayedReturn

+    # Notes:
+    # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
+    #    when the function steps are defined. After step-level arg serialization,
+    #    it is hard to update the s3_base_uri at pipeline compile time, so we set
+    #    a placeholder _S3BaseUriIdentifier here and let the runtime job resolve it.
+    # 2. The s3_root_uri is unknown because, when function steps are defined, the
+    #    sagemaker_session is not yet passed to the pipeline, and the default
+    #    s3_root_uri should be retrieved from the pipeline's sagemaker_session.
     def convert(arg):
         if isinstance(arg, DelayedReturn):
             return _DelayedReturn(
                 uri=[
-                    s3_base_uri,
+                    _S3BaseUriIdentifier(),
                     ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
                     arg._step.name,
                     "results",
```

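The mechanism introduced here is a pickle-safe placeholder: the final s3_base_uri is unknown when a step's arguments are serialized, so a sentinel object is stored among the URI components and swapped for the real value inside the runtime job. A minimal, self-contained sketch of that pattern, with simplified stand-ins for the SDK's classes:

```python
# Minimal sketch of the placeholder-resolution pattern above; the class and
# function names are illustrative stand-ins, not the SDK's actual API.
from dataclasses import dataclass


@dataclass
class S3BaseUriPlaceholder:
    """Stands in for _S3BaseUriIdentifier: marks a value unknown until runtime."""

    NAME = "S3_BASE_URI"


def resolve_uri(components, s3_base_uri):
    """Substitute the runtime-supplied base uri for any placeholder component."""
    resolved = [
        s3_base_uri if isinstance(c, S3BaseUriPlaceholder) else c for c in components
    ]
    return "/".join(resolved)


# Definition time: the base uri is unknown, so a placeholder is pickled instead.
uri_components = [S3BaseUriPlaceholder(), "execution-id", "step-name", "results"]

# Runtime: the job knows its base uri and resolves the placeholder.
print(resolve_uri(uri_components, "s3://bucket/pipeline-name"))
# -> s3://bucket/pipeline-name/execution-id/step-name/results
```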
src/sagemaker/remote_function/core/serialization.py (+37 -22)

```diff
@@ -161,17 +161,13 @@ def serialize_func_to_s3(
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-    bytes_to_upload = CloudpickleSerializer.serialize(func)

-    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
-
-    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
-
-    _upload_bytes_to_s3(
-        _MetaData(sha256_hash).to_json(),
-        f"{s3_uri}/metadata.json",
-        s3_kms_key,
-        sagemaker_session,
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(func),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
     )


@@ -220,17 +216,12 @@ def serialize_obj_to_s3(
         SerializationError: when fail to serialize object to bytes.
     """

-    bytes_to_upload = CloudpickleSerializer.serialize(obj)
-
-    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
-
-    sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
-
-    _upload_bytes_to_s3(
-        _MetaData(sha256_hash).to_json(),
-        f"{s3_uri}/metadata.json",
-        s3_kms_key,
-        sagemaker_session,
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(obj),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
     )


@@ -318,8 +309,32 @@ def serialize_exception_to_s3(
     """
     pickling_support.install()

-    bytes_to_upload = CloudpickleSerializer.serialize(exc)
+    _upload_payload_and_metadata_to_s3(
+        bytes_to_upload=CloudpickleSerializer.serialize(exc),
+        hmac_key=hmac_key,
+        s3_uri=s3_uri,
+        sagemaker_session=sagemaker_session,
+        s3_kms_key=s3_kms_key,
+    )

+
+def _upload_payload_and_metadata_to_s3(
+    bytes_to_upload: Union[bytes, io.BytesIO],
+    hmac_key: str,
+    s3_uri: str,
+    sagemaker_session: Session,
+    s3_kms_key,
+):
+    """Uploads serialized payload and metadata to s3.
+
+    Args:
+        bytes_to_upload (bytes): Serialized bytes to upload.
+        hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
+        s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+        sagemaker_session (sagemaker.session.Session):
+            The underlying Boto3 session which AWS service calls are delegated to.
+        s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+    """
     _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
```
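All three serialize_*_to_s3 functions previously repeated the same upload sequence; the new private helper centralizes it. Below is a standalone sketch of the underlying payload-plus-metadata pattern, with a print stub in place of the real _upload_bytes_to_s3 and a plain HMAC-SHA256 computation that approximates, but may differ in detail from, the SDK's _compute_hash:

```python
# Standalone sketch of the upload-payload-and-metadata pattern factored out
# above. _upload is a stub; the real code uploads to S3 with optional KMS.
import hashlib
import hmac
import json


def _upload(data: bytes, uri: str) -> None:
    print(f"uploading {len(data)} bytes to {uri}")


def upload_payload_and_metadata(bytes_to_upload: bytes, hmac_key: str, s3_uri: str) -> None:
    # Payload first, then a metadata file carrying an integrity hash so the
    # loader can verify the payload was not tampered with.
    _upload(bytes_to_upload, f"{s3_uri}/payload.pkl")
    sha256_hash = hmac.new(hmac_key.encode(), bytes_to_upload, hashlib.sha256).hexdigest()
    _upload(json.dumps({"sha256_hash": sha256_hash}).encode(), f"{s3_uri}/metadata.json")


upload_payload_and_metadata(b"pickled-function-bytes", "secret-key", "s3://bucket/step/function")
```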

src/sagemaker/remote_function/core/stored_function.py (+41 -0)

```diff
@@ -14,6 +14,7 @@
 from __future__ import absolute_import

 import os
+from dataclasses import dataclass
 from typing import Any


@@ -36,6 +37,14 @@
 JSON_RESULTS_FILE = "results.json"


+@dataclass
+class _SerializedData:
+    """Data class to store serialized function and arguments"""
+
+    func: bytes
+    args: bytes
+
+
 class StoredFunction:
     """Class representing a remote function stored in S3."""

@@ -105,6 +114,38 @@ def save(self, func, *args, **kwargs):
             s3_kms_key=self.s3_kms_key,
         )

+    def save_pipeline_step_function(self, serialized_data):
+        """Upload serialized function and arguments to s3.
+
+        Args:
+            serialized_data (_SerializedData): The serialized function
+                and function arguments of a function step.
+        """
+
+        logger.info(
+            "Uploading serialized function code to %s",
+            s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
+        )
+        serialization._upload_payload_and_metadata_to_s3(
+            bytes_to_upload=serialized_data.func,
+            hmac_key=self.hmac_key,
+            s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
+            sagemaker_session=self.sagemaker_session,
+            s3_kms_key=self.s3_kms_key,
+        )
+
+        logger.info(
+            "Uploading serialized function arguments to %s",
+            s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
+        )
+        serialization._upload_payload_and_metadata_to_s3(
+            bytes_to_upload=serialized_data.args,
+            hmac_key=self.hmac_key,
+            s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
+            sagemaker_session=self.sagemaker_session,
+            s3_kms_key=self.s3_kms_key,
+        )
+
     def load_and_invoke(self) -> Any:
         """Load and deserialize the function and the arguments and then execute it."""
```
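save_pipeline_step_function uploads the two pre-serialized blobs under separate folders of the step's upload path, each paired with its metadata file. A toy sketch of the resulting layout, using a dict as the bucket; the folder values below are inferred from the FUNCTION_FOLDER and ARGUMENTS_FOLDER constant names and should be treated as assumptions:

```python
# Toy sketch of the S3 layout produced by save_pipeline_step_function.
# Folder names are assumed; a dict stands in for the bucket.
from dataclasses import dataclass

FUNCTION_FOLDER = "function"    # assumed value
ARGUMENTS_FOLDER = "arguments"  # assumed value


@dataclass
class SerializedData:
    func: bytes
    args: bytes


def save_pipeline_step_function(bucket: dict, upload_path: str, data: SerializedData) -> None:
    # Each folder gets a payload plus a metadata file with an integrity hash.
    for folder, payload in ((FUNCTION_FOLDER, data.func), (ARGUMENTS_FOLDER, data.args)):
        bucket[f"{upload_path}/{folder}/payload.pkl"] = payload
        bucket[f"{upload_path}/{folder}/metadata.json"] = b'{"sha256_hash": "..."}'


bucket = {}
save_pipeline_step_function(
    bucket, "s3://bucket/pipeline/build-time/step-name", SerializedData(b"<func>", b"<args>")
)
print("\n".join(sorted(bucket)))
```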

src/sagemaker/remote_function/invoke_function.py (+2 -0)

```diff
@@ -84,6 +84,7 @@ def _load_pipeline_context(args) -> Context:
     property_references = args.property_references
     serialize_output_to_json = args.serialize_output_to_json
     func_step_s3_dir = args.func_step_s3_dir
+    s3_base_uri = args.s3_base_uri

     property_references_dict = {}
     for i in range(0, len(property_references), 2):
@@ -94,6 +95,7 @@ def _load_pipeline_context(args) -> Context:
         property_references=property_references_dict,
         serialize_output_to_json=serialize_output_to_json,
         func_step_s3_dir=func_step_s3_dir,
+        s3_base_uri=s3_base_uri,
     )
```
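The runtime entry point now receives the base uri as one more job argument and threads it into the Context that resolve_pipeline_variables consumes. A minimal argparse sketch of that wiring; the exact flag spelling used by the SDK's job launcher is an assumption here:

```python
# Minimal sketch of threading a new CLI argument into the runtime context.
# The "--s3_base_uri" flag spelling is an assumption based on the attribute name.
import argparse
from dataclasses import dataclass


@dataclass
class Context:
    func_step_s3_dir: str = None
    s3_base_uri: str = None


parser = argparse.ArgumentParser()
parser.add_argument("--func_step_s3_dir", default=None)
parser.add_argument("--s3_base_uri", default=None)
args = parser.parse_args(["--s3_base_uri", "s3://bucket/my-pipeline"])

context = Context(func_step_s3_dir=args.func_step_s3_dir, s3_base_uri=args.s3_base_uri)
print(context)
```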

src/sagemaker/remote_function/job.py (+5 -11)

```diff
@@ -52,11 +52,8 @@
 from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
 from sagemaker.s3 import s3_path_join, S3Uploader
 from sagemaker import vpc_utils
-from sagemaker.remote_function.core.stored_function import StoredFunction
-from sagemaker.remote_function.core.pipeline_variables import (
-    Context,
-    convert_pipeline_variables_to_pickleable,
-)
+from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
+from sagemaker.remote_function.core.pipeline_variables import Context
 from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
     RuntimeEnvironmentManager,
     _DependencySettings,
@@ -695,6 +692,7 @@ def compile(
         func_args: tuple,
         func_kwargs: dict,
        run_info=None,
+        serialized_data: _SerializedData = None,
    ) -> dict:
        """Build the artifacts and generate the training job request."""
        from sagemaker.workflow.properties import Properties
@@ -732,12 +730,8 @@ def compile(
                func_step_s3_dir=step_compilation_context.pipeline_build_time,
            ),
        )
-        converted_func_args, converted_func_kwargs = convert_pipeline_variables_to_pickleable(
-            s3_base_uri=s3_base_uri,
-            func_args=func_args,
-            func_kwargs=func_kwargs,
-        )
-        stored_function.save(func, *converted_func_args, **converted_func_kwargs)
+
+        stored_function.save_pipeline_step_function(serialized_data)

        stopping_condition = {
            "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
```

src/sagemaker/workflow/function_step.py (+21 -0)

```diff
@@ -83,6 +83,11 @@ def __init__(
             func_kwargs (dict): keyword arguments of the python function.
             **kwargs: Additional arguments to be passed to the `step` decorator.
         """
+        from sagemaker.remote_function.core.pipeline_variables import (
+            convert_pipeline_variables_to_pickleable,
+        )
+        from sagemaker.remote_function.core.serialization import CloudpickleSerializer
+        from sagemaker.remote_function.core.stored_function import _SerializedData

         super(_FunctionStep, self).__init__(
             name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies
@@ -96,6 +101,21 @@ def __init__(

         self.__job_settings = None

+        (
+            self._converted_func_args,
+            self._converted_func_kwargs,
+        ) = convert_pipeline_variables_to_pickleable(
+            func_args=self._func_args,
+            func_kwargs=self._func_kwargs,
+        )
+
+        self._serialized_data = _SerializedData(
+            func=CloudpickleSerializer.serialize(self._func),
+            args=CloudpickleSerializer.serialize(
+                (self._converted_func_args, self._converted_func_kwargs)
+            ),
+        )
+
     @property
     def func(self):
         """The python function to run as a pipeline step."""
@@ -185,6 +205,7 @@ def arguments(self) -> RequestType:
             func=self.func,
             func_args=self.func_args,
             func_kwargs=self.func_kwargs,
+            serialized_data=self._serialized_data,
         )
         # Continue to pop job name if not explicitly opted-in via config
         request_dict = trim_request_dict(request_dict, "TrainingJobName", step_compilation_context)
```
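This is the heart of the commit: the function and its converted arguments are cloudpickled inside _FunctionStep.__init__, at step definition time, so later rebinding of names in the enclosing scope can no longer leak into the payload that compile() uploads. A small illustration of the definition-time snapshot, with plain cloudpickle standing in for CloudpickleSerializer (assumes cloudpickle is installed and the snippet runs as a script, so the function is pickled by value):

```python
# Illustration of definition-time serialization; run as a script so cloudpickle
# pickles the __main__-level function by value.
import pickle

import cloudpickle


class FunctionStepSketch:
    """Toy stand-in for _FunctionStep: snapshots func and args on construction."""

    def __init__(self, func, func_args, func_kwargs):
        self.serialized_func = cloudpickle.dumps(func)
        self.serialized_args = cloudpickle.dumps((func_args, func_kwargs))


def greet(name):
    return f"hello {name}"


step = FunctionStepSketch(greet, ("world",), {})
del greet  # rebinding or deleting the name after this point no longer matters

restored = pickle.loads(step.serialized_func)
args, kwargs = pickle.loads(step.serialized_args)
print(restored(*args, **kwargs))  # -> hello world
```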

tests/integ/sagemaker/workflow/test_step_decorator.py (+58 -0)

```diff
@@ -858,3 +858,61 @@ def cuberoot(x):
         pipeline.delete()
     except Exception:
         pass
+
+
+def test_step_level_serialization(
+    sagemaker_session, role, pipeline_name, region_name, dummy_container_without_error
+):
+    os.environ["AWS_DEFAULT_REGION"] = region_name
+
+    _EXPECTED_STEP_A_OUTPUT = "This pipeline is a function."
+    _EXPECTED_STEP_B_OUTPUT = "This generates a function arg."
+
+    step_config = dict(
+        role=role,
+        image_uri=dummy_container_without_error,
+        instance_type=INSTANCE_TYPE,
+    )
+
+    # This `pipeline` function may clash with the `pipeline` object
+    # defined below. However, if function and args serialization happens
+    # at the step level, this clash won't happen.
+    def pipeline():
+        return _EXPECTED_STEP_A_OUTPUT
+
+    @step(**step_config)
+    def generator():
+        return _EXPECTED_STEP_B_OUTPUT
+
+    @step(**step_config)
+    def func_with_collision(var: str):
+        return f"{pipeline()} {var}"
+
+    step_output_a = generator()
+    step_output_b = func_with_collision(step_output_a)
+
+    pipeline = Pipeline(  # noqa: F811
+        name=pipeline_name,
+        steps=[step_output_b],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        create_and_execute_pipeline(
+            pipeline=pipeline,
+            pipeline_name=pipeline_name,
+            region_name=region_name,
+            role=role,
+            no_of_steps=2,
+            last_step_name=get_step(step_output_b).name,
+            execution_parameters=dict(),
+            step_status="Succeeded",
+            step_result_type=str,
+            step_result_value=f"{_EXPECTED_STEP_A_OUTPUT} {_EXPECTED_STEP_B_OUTPUT}",
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
```
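The test exercises cloudpickle's capture semantics: a function defined in __main__ is pickled by value together with the current values of the globals it references, so the moment of serialization decides which binding of `pipeline` the step body sees. A compact illustration of the failure mode that compile-time serialization risked (plain cloudpickle, run as a script):

```python
# Illustration of the name-collision hazard: whichever value `pipeline` holds
# at dump time is what the deserialized step body will call.
import pickle

import cloudpickle


def pipeline():  # helper the step body calls, as in the integration test
    return "a function"


def step_body():
    return pipeline()


early = cloudpickle.dumps(step_body)  # step-level serialization: snapshot now

pipeline = "a Pipeline object"  # noqa: F811  -- name rebound, as in the test

late = cloudpickle.dumps(step_body)  # compile-level serialization: too late

print(pickle.loads(early)())  # -> "a function"
# pickle.loads(late)() would fail: TypeError: 'str' object is not callable
```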
