fix: Move func and args serialization of function step to step level #4312

Merged: 1 commit, Dec 14, 2023
41 changes: 37 additions & 4 deletions src/sagemaker/remote_function/core/pipeline_variables.py
@@ -77,6 +77,17 @@ class _ExecutionVariable:
name: str


@dataclass
class _S3BaseUriIdentifier:
"""Identifies that the class refers to a function step's s3 base uri.

The s3_base_uri = s3_root_uri + pipeline_name.
This identifier is resolved at function step runtime by the SDK.
"""

NAME = "S3_BASE_URI"


@dataclass
class _DelayedReturn:
"""Delayed return from a function."""
@@ -155,6 +166,7 @@ def __init__(
hmac_key: str,
parameter_resolver: _ParameterResolver,
execution_variable_resolver: _ExecutionVariableResolver,
s3_base_uri: str,
**settings,
):
"""Resolve delayed return.
@@ -164,8 +176,12 @@ def __init__(
hmac_key: key used to encrypt serialized and deserialized function and arguments.
parameter_resolver: resolver used to resolve pipeline parameters.
execution_variable_resolver: resolver used to resolve execution variables.
s3_base_uri (str): the s3 base uri of the function step that
the serialized artifacts will be uploaded to.
The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
"""
self._s3_base_uri = s3_base_uri
self._parameter_resolver = parameter_resolver
self._execution_variable_resolver = execution_variable_resolver
# different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
uri.append(self._parameter_resolver.resolve(component))
elif isinstance(component, _ExecutionVariable):
uri.append(self._execution_variable_resolver.resolve(component))
elif isinstance(component, _S3BaseUriIdentifier):
uri.append(self._s3_base_uri)
else:
uri.append(component)
return s3_path_join(*uri)
@@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):


def resolve_pipeline_variables(
context: Context, func_args: Tuple, func_kwargs: Dict, hmac_key: str, **settings
context: Context,
func_args: Tuple,
func_kwargs: Dict,
hmac_key: str,
s3_base_uri: str,
**settings,
):
"""Resolve pipeline variables.

@@ -228,6 +251,8 @@ def resolve_pipeline_variables(
func_args: function args.
func_kwargs: function kwargs.
hmac_key: key used to encrypt serialized and deserialized function and arguments.
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
"""

@@ -251,6 +276,7 @@ def resolve_pipeline_variables(
hmac_key=hmac_key,
parameter_resolver=parameter_resolver,
execution_variable_resolver=execution_variable_resolver,
s3_base_uri=s3_base_uri,
**settings,
)

@@ -289,11 +315,10 @@ def resolve_pipeline_variables(
return resolved_func_args, resolved_func_kwargs


def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, func_kwargs: Dict):
def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict):
"""Convert pipeline variables to pickleable.

Args:
s3_base_uri: s3 base uri where artifacts are stored.
func_args: function args.
func_kwargs: function kwargs.
"""
@@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,

from sagemaker.workflow.function_step import DelayedReturn

# Notes:
# 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
#    when defining function steps. After step-level arg serialization,
#    it is hard to update the s3_base_uri at pipeline compile time.
#    Thus we set a placeholder, _S3BaseUriIdentifier, and let the runtime job resolve it.
# 2. The s3_root_uri is unknown at step definition time because the pipeline's
#    sagemaker_session is not passed in at that point, yet the default s3_root_uri
#    should be retrieved from the pipeline's sagemaker_session.
def convert(arg):
if isinstance(arg, DelayedReturn):
return _DelayedReturn(
uri=[
s3_base_uri,
_S3BaseUriIdentifier(),
ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
arg._step.name,
"results",
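Note: the placeholder mechanism introduced above can be summarized with a small, self-contained sketch. The helper `resolve_uri_components` and the sample values are illustrative stand-ins; in the SDK the resolution happens inside `_DelayedReturnResolver._resolve_delayed_return_uri`, as shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class _S3BaseUriIdentifier:
    """Placeholder stored in the delayed-return uri at pipeline definition time."""

    NAME = "S3_BASE_URI"


def resolve_uri_components(components, s3_base_uri):
    """Mirror of _resolve_delayed_return_uri: swap placeholders for concrete values."""
    resolved = []
    for component in components:
        if isinstance(component, _S3BaseUriIdentifier):
            resolved.append(s3_base_uri)  # filled in by the runtime job
        else:
            resolved.append(component)  # literal components pass through unchanged
    return "/".join(resolved)


# At definition time the s3 base uri is unknown, so the uri list holds the placeholder:
uri = [_S3BaseUriIdentifier(), "execution-id", "step-name", "results"]
# At runtime the job knows s3_root_uri + pipeline_name and resolves it:
print(resolve_uri_components(uri, "s3://my-bucket/my-pipeline"))
# -> s3://my-bucket/my-pipeline/execution-id/step-name/results
```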
59 changes: 37 additions & 22 deletions src/sagemaker/remote_function/core/serialization.py
@@ -161,17 +161,13 @@ def serialize_func_to_s3(
Raises:
SerializationError: when failing to serialize the function to bytes.
"""
bytes_to_upload = CloudpickleSerializer.serialize(func)

_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
f"{s3_uri}/metadata.json",
s3_kms_key,
sagemaker_session,
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(func),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


@@ -220,17 +216,12 @@ def serialize_obj_to_s3(
SerializationError: when failing to serialize the object to bytes.
"""

bytes_to_upload = CloudpickleSerializer.serialize(obj)

_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
f"{s3_uri}/metadata.json",
s3_kms_key,
sagemaker_session,
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(obj),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


@@ -318,8 +309,32 @@ def serialize_exception_to_s3(
"""
pickling_support.install()

bytes_to_upload = CloudpickleSerializer.serialize(exc)
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(exc),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


def _upload_payload_and_metadata_to_s3(
bytes_to_upload: Union[bytes, io.BytesIO],
hmac_key: str,
s3_uri: str,
sagemaker_session: Session,
s3_kms_key,
):
"""Uploads serialized payload and metadata to s3.

Args:
bytes_to_upload (bytes): Serialized bytes to upload.
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
sagemaker_session (sagemaker.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
"""
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
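Note: serialize_func_to_s3, serialize_obj_to_s3, and serialize_exception_to_s3 previously repeated the same "upload the payload, then upload a metadata file containing its hash" sequence; the new private helper deduplicates it. A toy sketch of that pattern follows, using HMAC-SHA256 as described in the hmac_key docstrings; the JSON layout produced by `_MetaData.to_json()` is not shown in this diff, so the `sha256_hash` field name below is an assumption.

```python
import hashlib
import hmac
import json


def build_payload_and_metadata(payload: bytes, hmac_key: str):
    """Toy equivalent of _upload_payload_and_metadata_to_s3: produce the two objects
    that would be written to <s3_uri>/payload.pkl and <s3_uri>/metadata.json."""
    sha256_hash = hmac.new(hmac_key.encode("utf-8"), payload, hashlib.sha256).hexdigest()
    metadata = json.dumps({"sha256_hash": sha256_hash})  # field name assumed
    return payload, metadata


payload, metadata = build_payload_and_metadata(b"pickled-bytes", "secret-key")
print(metadata)
```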
42 changes: 42 additions & 0 deletions src/sagemaker/remote_function/core/stored_function.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

import os
from dataclasses import dataclass
from typing import Any


@@ -36,6 +37,14 @@
JSON_RESULTS_FILE = "results.json"


@dataclass
class _SerializedData:
"""Data class to store serialized function and arguments"""

func: bytes
args: bytes


class StoredFunction:
"""Class representing a remote function stored in S3."""

@@ -105,6 +114,38 @@ def save(self, func, *args, **kwargs):
s3_kms_key=self.s3_kms_key,
)

def save_pipeline_step_function(self, serialized_data):
"""Upload serialized function and arguments to s3.

Args:
serialized_data (_SerializedData): The serialized function
and function arguments of a function step.
"""

logger.info(
"Uploading serialized function code to %s",
s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.func,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
)

logger.info(
"Uploading serialized function arguments to %s",
s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.args,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
)

def load_and_invoke(self) -> Any:
"""Load and deserialize the function and the arguments and then execute it."""

@@ -134,6 +175,7 @@ def load_and_invoke(self) -> Any:
args,
kwargs,
hmac_key=self.hmac_key,
s3_base_uri=self.s3_base_uri,
sagemaker_session=self.sagemaker_session,
)

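Note: a condensed sketch of the new upload flow for function steps. The step hands over pre-serialized bytes and save_pipeline_step_function writes the function and its arguments under separate prefixes; the folder names "function" and "arguments" and the `upload` callable below are assumptions, since the diff only references the FUNCTION_FOLDER and ARGUMENTS_FOLDER constants by name.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class SerializedData:
    """Mirrors the _SerializedData dataclass added in stored_function.py."""

    func: bytes
    args: bytes


def save_pipeline_step_function_sketch(
    data: SerializedData,
    func_upload_path: str,
    upload: Callable[[bytes, str], None],
):
    """Outline of StoredFunction.save_pipeline_step_function: upload the
    pre-serialized function and arguments to separate S3 prefixes."""
    upload(data.func, f"{func_upload_path}/function")  # assumed folder name
    upload(data.args, f"{func_upload_path}/arguments")  # assumed folder name


# Example with a fake uploader that only records the destination uris.
uploads = []
save_pipeline_step_function_sketch(
    SerializedData(func=b"f-bytes", args=b"a-bytes"),
    "s3://bucket/pipeline/step",
    lambda payload, uri: uploads.append(uri),
)
print(uploads)
```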
16 changes: 5 additions & 11 deletions src/sagemaker/remote_function/job.py
@@ -52,11 +52,8 @@
from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
from sagemaker.s3 import s3_path_join, S3Uploader
from sagemaker import vpc_utils
from sagemaker.remote_function.core.stored_function import StoredFunction
from sagemaker.remote_function.core.pipeline_variables import (
Context,
convert_pipeline_variables_to_pickleable,
)
from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
from sagemaker.remote_function.core.pipeline_variables import Context
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
RuntimeEnvironmentManager,
_DependencySettings,
@@ -695,6 +692,7 @@ def compile(
func_args: tuple,
func_kwargs: dict,
run_info=None,
serialized_data: _SerializedData = None,
) -> dict:
"""Build the artifacts and generate the training job request."""
from sagemaker.workflow.properties import Properties
@@ -732,12 +730,8 @@ def compile(
func_step_s3_dir=step_compilation_context.pipeline_build_time,
),
)
converted_func_args, converted_func_kwargs = convert_pipeline_variables_to_pickleable(
s3_base_uri=s3_base_uri,
func_args=func_args,
func_kwargs=func_kwargs,
)
stored_function.save(func, *converted_func_args, **converted_func_kwargs)

stored_function.save_pipeline_step_function(serialized_data)

stopping_condition = {
"MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
21 changes: 21 additions & 0 deletions src/sagemaker/workflow/function_step.py
@@ -83,6 +83,11 @@ def __init__(
func_kwargs (dict): keyword arguments of the python function.
**kwargs: Additional arguments to be passed to the `step` decorator.
"""
from sagemaker.remote_function.core.pipeline_variables import (
convert_pipeline_variables_to_pickleable,
)
from sagemaker.remote_function.core.serialization import CloudpickleSerializer
from sagemaker.remote_function.core.stored_function import _SerializedData

super(_FunctionStep, self).__init__(
name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies
@@ -96,6 +101,21 @@

self.__job_settings = None

(
self._converted_func_args,
self._converted_func_kwargs,
) = convert_pipeline_variables_to_pickleable(
func_args=self._func_args,
func_kwargs=self._func_kwargs,
)

self._serialized_data = _SerializedData(
func=CloudpickleSerializer.serialize(self._func),
args=CloudpickleSerializer.serialize(
(self._converted_func_args, self._converted_func_kwargs)
),
)

@property
def func(self):
"""The python function to run as a pipeline step."""
@@ -185,6 +205,7 @@ def arguments(self) -> RequestType:
func=self.func,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
serialized_data=self._serialized_data,
)
# Continue to pop job name if not explicitly opted-in via config
request_dict = trim_request_dict(request_dict, "TrainingJobName", step_compilation_context)
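Note: the net effect of the PR, modeled as a toy class. Serialization of the function and its converted arguments now happens once, when the step is defined, and the compile() path in job.py only uploads the stored bytes. cloudpickle is used directly here in place of the SDK's CloudpickleSerializer, and the class below is illustrative rather than the real _FunctionStep.

```python
from dataclasses import dataclass

import cloudpickle


@dataclass
class SerializedData:
    func: bytes
    args: bytes


class FunctionStepSketch:
    """Toy model of the new flow: serialize once at step definition time,
    reuse the same bytes every time the step arguments are compiled."""

    def __init__(self, func, func_args, func_kwargs):
        # In the SDK, pipeline variables inside args/kwargs are first converted to
        # pickleable placeholders; plain values are used here for brevity.
        self._serialized_data = SerializedData(
            func=cloudpickle.dumps(func),
            args=cloudpickle.dumps((func_args, func_kwargs)),
        )

    def compile(self):
        # compile() no longer re-pickles; it just hands the stored bytes to the uploader.
        return self._serialized_data


step = FunctionStepSketch(lambda x: x + 1, (41,), {})
print(len(step.compile().func), "bytes of pickled function")
```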