
Commit 4ac09d7

feature: serialize objs using dask (aws#907)

Committed by Namrata Madan
Co-authored-by: Namrata Madan <[email protected]>
1 parent 56b1291

File tree

9 files changed: +708, -267 lines changed

```diff
@@ -1,2 +1,2 @@
-cloudpickle==2.2.0
+cloudpickle==2.2.1
 tblib==1.7.0
```

requirements/extras/test_requirements.txt (+2, -1)
```diff
@@ -21,4 +21,5 @@ sagemaker-experiments==0.1.35
 Jinja2==3.0.3
 pandas>=1.3.5,<1.5
 scikit-learn==1.0.2
-cloudpickle==2.2.0
+cloudpickle==2.2.1
+distributed==2022.2.0
```

src/sagemaker/remote_function/core/serialization.py (+133, -49)
```diff
@@ -17,26 +17,43 @@
 import json
 import os
 import sys
+import pickle
+from enum import Enum
 
 import cloudpickle
 
 from typing import Any, Callable
+
+from sagemaker.s3 import s3_path_join
+
 from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
 from sagemaker.s3 import S3Downloader, S3Uploader
 from tblib import pickling_support
 
+METADATA_FILE = "metadata.json"
+PAYLOAD_FILE = "payload.pkl"
+HEADER_FILE = "headers.pkl"
+FRAME_FILE = "frame-{}.dat"
+
 
 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 
 
+class SerializationModule(str, Enum):
+    """Represents various serialization modules used."""
+
+    CLOUDPICKLE = "cloudpickle"
+    DASK = "dask"
+
+
 @dataclasses.dataclass
 class _MetaData:
     """Metadata about the serialized data or functions."""
 
+    serialization_module: SerializationModule
     version: str = "2023-04-24"
     python_version: str = _get_python_version()
-    serialization_module: str = "cloudpickle"
 
     def to_json(self):
         return json.dumps(dataclasses.asdict(self)).encode()
```
```diff
@@ -45,16 +62,13 @@ def to_json(self):
     def from_json(s):
         try:
             obj = json.loads(s)
-        except json.decoder.JSONDecodeError:
+            metadata = _MetaData(**obj)
+        except (json.decoder.JSONDecodeError, TypeError):
             raise DeserializationError("Corrupt metadata file. It is not a valid json file.")
 
-        metadata = _MetaData()
-        metadata.version = obj.get("version")
-        metadata.python_version = obj.get("python_version")
-        metadata.serialization_module = obj.get("serialization_module")
-
-        if not (
-            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        if (
+            metadata.version != "2023-04-24"
+            or metadata.serialization_module not in SerializationModule.__members__.values()
         ):
             raise DeserializationError(
                 f"Corrupt metadata file. Serialization approach {s} is not supported."
```
```diff
@@ -79,6 +93,12 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         Raises:
             SerializationError: when fail to serialize object to bytes.
         """
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
         try:
             bytes_to_upload = cloudpickle.dumps(obj)
         except Exception as e:
```
```diff
@@ -95,7 +115,76 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                 "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
             ) from e
 
-        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+        _upload_bytes_to_s3(
+            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
+        )
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
+                AWS service calls are delegated to.
+            s3_uri (str): S3 root uri from which the serialized artifacts are downloaded.
+        Returns:
+            The deserialized python object.
+        Raises:
+            DeserializationError: when fail to deserialize the downloaded bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(
+            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
+        )
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(
+                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
+                )
+            ) from e
+
+
+class DaskSerializer:
+    """Serializer using Dask."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
+                service calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: object to be serialized and persisted
+        Raises:
+            SerializationError: when fail to serialize object to bytes.
+        """
+        import distributed.protocol as dask
+
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.DASK).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
+        try:
+            header, frames = dask.serialize(obj, on_error="raise")
+        except Exception as e:
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(
+            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
+        )
+        for idx, frame in enumerate(frames):
+            frame = bytes(frame) if isinstance(frame, memoryview) else frame
+            _upload_bytes_to_s3(
+                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
+            )
 
     @staticmethod
     def deserialize(sagemaker_session, s3_uri) -> Any:
```
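For context on the header/frames split used above: `distributed.protocol.serialize` decomposes an object into a small picklable header plus a list of binary frames, which is why the serializer persists `headers.pkl` and numbered `frame-{}.dat` objects. A standalone sketch of that round trip, assuming `distributed` (pinned to 2022.2.0 in this commit) is installed and using NumPy only as sample data:

```python
# Sketch of the header/frames round trip behind DaskSerializer; assumes
# distributed==2022.2.0 is installed; NumPy is just sample data.
import pickle

import numpy as np
from distributed.protocol import deserialize, serialize

obj = {"weights": np.arange(10.0)}

# serialize() returns a picklable header (a dict describing the object)
# and a list of binary frames (bytes/memoryview buffers).
header, frames = serialize(obj, on_error="raise")

# Mimic the S3 round trip above: pickle the header, concretize frames.
stored_header = pickle.dumps(header)
stored_frames = [bytes(f) for f in frames]

restored = deserialize(pickle.loads(stored_header), stored_frames)
assert (restored["weights"] == obj["weights"]).all()
```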
```diff
@@ -110,19 +199,29 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
         Raises:
             DeserializationError: when fail to deserialize the downloaded bytes.
         """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+        import distributed.protocol as dask
 
+        header_to_deserialize = _read_bytes_from_s3(
+            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
+        )
+        headers = pickle.loads(header_to_deserialize)
+        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
+        frames = []
+        for idx in range(num_frames):
+            frame = _read_bytes_from_s3(
+                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
+            )
+            frames.append(frame)
         try:
-            return cloudpickle.loads(bytes_to_deserialize)
+            return dask.deserialize(headers, frames)
         except Exception as e:
             raise DeserializationError(
                 "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
             ) from e
 
 
-# TODO: use dask serializer in case dask distributed is installed in users' environment.
 def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function and uploads it to S3.
+    """Serializes function using cloudpickle and uploads it to S3.
 
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
```
```diff
@@ -133,13 +232,7 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)
 
 
 def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
```
```diff
@@ -157,16 +250,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
     Raises:
         DeserializationError: when fail to deserialize function from bytes.
     """
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
 
 
 def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
     """Serializes data object and uploads it to S3.
 
+    This method uses the Dask library to perform serialization if it is already installed;
+    otherwise, it uses cloudpickle.
+
     Args:
         sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
             calls are delegated to.
```
```diff
@@ -177,12 +270,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
         SerializationError: when fail to serialize object to bytes.
     """
 
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    try:
+        import distributed.protocol as dask  # noqa: F401
+
+        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    except ModuleNotFoundError:
+        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
 
 
 def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
```
```diff
@@ -197,12 +290,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
     Raises:
         DeserializationError: when fail to deserialize object from bytes.
     """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    metadata = _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
     )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    if metadata.serialization_module == SerializationModule.DASK:
+        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
 
 
 def serialize_exception_to_s3(
```
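Taken together, the writer-side fallback and the reader-side dispatch on `metadata.json` give a symmetric round trip. A hypothetical usage sketch; the session and the S3 URI below are placeholders, not values from the commit:

```python
# Hypothetical usage of the module-level helpers; the session and the
# bucket/prefix below are placeholders, not values from the commit.
import numpy as np

from sagemaker.remote_function.core import serialization
from sagemaker.session import Session

session = Session()
s3_root = "s3://my-bucket/remote-function/run-01"

data = {"features": np.ones((4, 4))}

# Writes metadata.json plus either headers.pkl and frame-0.dat (Dask
# path) or payload.pkl (cloudpickle fallback) under s3_root.
serialization.serialize_obj_to_s3(data, session, s3_root)

# The reader dispatches on metadata.json's serialization_module, so it
# works regardless of which serializer produced the artifacts.
restored = serialization.deserialize_obj_from_s3(session, s3_root)
```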
```diff
@@ -220,12 +313,7 @@
         SerializationError: when fail to serialize object to bytes.
     """
     pickling_support.install()
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)
 
 
 def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
```
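The retained `pickling_support.install()` call is what lets exceptions cross the wire with their stack: plain pickle refuses traceback objects. A sketch of the underlying tblib behavior, following the pattern in tblib's documentation:

```python
# Sketch: tblib makes traceback objects picklable; without install(),
# pickling sys.exc_info() fails with "cannot pickle 'traceback' object".
import pickle
import sys

from tblib import pickling_support

pickling_support.install()

try:
    1 / 0
except ZeroDivisionError:
    payload = pickle.dumps(sys.exc_info())  # (type, value, traceback)

exc_type, exc_value, tb = pickle.loads(payload)
print(exc_type.__name__)  # -> ZeroDivisionError
```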
```diff
@@ -240,12 +328,8 @@
     Raises:
         DeserializationError: when fail to deserialize exception from bytes.
     """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
```

src/sagemaker/remote_function/core/stored_function.py (-1)
```diff
@@ -18,7 +18,6 @@
 
 import sagemaker.remote_function.core.serialization as serialization
 
-
 logger = logging_config.get_logger()
 
 
```
```diff
@@ -0,0 +1 @@
+distributed==2022.2.0
```
