import json
import os
import sys
+import pickle
+from enum import Enum

import cloudpickle

from typing import Any, Callable
+
+from sagemaker.s3 import s3_path_join
+
from sagemaker.remote_function.errors import ServiceError, SerializationError, DeserializationError
from sagemaker.s3 import S3Downloader, S3Uploader
from tblib import pickling_support

+METADATA_FILE = "metadata.json"
+PAYLOAD_FILE = "payload.pkl"
+HEADER_FILE = "headers.pkl"
+FRAME_FILE = "frame-{}.dat"
+

def _get_python_version():
    return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"


+class SerializationModule(str, Enum):
+    """Represents various serialization modules used."""
+
+    CLOUDPICKLE = "cloudpickle"
+    DASK = "dask"
+
+
@dataclasses.dataclass
class _MetaData:
    """Metadata about the serialized data or functions."""

+    serialization_module: SerializationModule
    version: str = "2023-04-24"
    python_version: str = _get_python_version()
-    serialization_module: str = "cloudpickle"

    def to_json(self):
        return json.dumps(dataclasses.asdict(self)).encode()
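Note: with these constants, every artifact for one invocation lives under a single S3 prefix. A minimal sketch of the resulting key layout, using a made-up bucket and prefix (the real s3_uri is supplied by the caller):

```python
# Sketch of the per-invocation S3 key layout; the prefix below is hypothetical.
import os

METADATA_FILE = "metadata.json"   # same constants as introduced above
PAYLOAD_FILE = "payload.pkl"
HEADER_FILE = "headers.pkl"
FRAME_FILE = "frame-{}.dat"

s3_uri = "s3://example-bucket/remote-function/run-1234"
print(os.path.join(s3_uri, METADATA_FILE))         # always written first
print(os.path.join(s3_uri, PAYLOAD_FILE))          # cloudpickle payload
print(os.path.join(s3_uri, HEADER_FILE))           # Dask header (Dask path only)
print(os.path.join(s3_uri, FRAME_FILE.format(0)))  # Dask frames: frame-0.dat, frame-1.dat, ...
```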
@@ -45,16 +62,13 @@ def to_json(self):
    def from_json(s):
        try:
            obj = json.loads(s)
-        except json.decoder.JSONDecodeError:
+            metadata = _MetaData(**obj)
+        except (json.decoder.JSONDecodeError, TypeError):
            raise DeserializationError("Corrupt metadata file. It is not a valid json file.")

-        metadata = _MetaData()
-        metadata.version = obj.get("version")
-        metadata.python_version = obj.get("python_version")
-        metadata.serialization_module = obj.get("serialization_module")
-
-        if not (
-            metadata.version == "2023-04-24" and metadata.serialization_module == "cloudpickle"
+        if (
+            metadata.version != "2023-04-24"
+            or metadata.serialization_module not in SerializationModule.__members__.values()
        ):
            raise DeserializationError(
                f"Corrupt metadata file. Serialization approach {s} is not supported."
@@ -79,6 +93,12 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
        Raises:
            SerializationError: when fail to serialize object to bytes.
        """
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.CLOUDPICKLE).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
        try:
            bytes_to_upload = cloudpickle.dumps(obj)
        except Exception as e:
@@ -95,7 +115,76 @@ def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
            ) from e

-        _upload_bytes_to_s3(bytes_to_upload, s3_uri, s3_kms_key, sagemaker_session)
+        _upload_bytes_to_s3(
+            bytes_to_upload, os.path.join(s3_uri, PAYLOAD_FILE), s3_kms_key, sagemaker_session
+        )
+
+    @staticmethod
+    def deserialize(sagemaker_session, s3_uri) -> Any:
+        """Downloads from S3 and then deserializes data objects.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying sagemaker session which
+                AWS service calls are delegated to.
+            s3_uri (str): S3 root uri from which the serialized artifacts are downloaded.
+        Returns:
+            List of deserialized python objects.
+        Raises:
+            DeserializationError: when fail to deserialize the downloaded bytes.
+        """
+        bytes_to_deserialize = _read_bytes_from_s3(
+            os.path.join(s3_uri, PAYLOAD_FILE), sagemaker_session
+        )
+
+        try:
+            return cloudpickle.loads(bytes_to_deserialize)
+        except Exception as e:
+            raise DeserializationError(
+                "Error when deserializing bytes downloaded from {}: {}".format(
+                    os.path.join(s3_uri, PAYLOAD_FILE), repr(e)
+                )
+            ) from e
+
+
+class DaskSerializer:
+    """Serializer using Dask."""
+
+    @staticmethod
+    def serialize(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
+        """Serializes data object and uploads it to S3.
+
+        Args:
+            sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS
+                service calls are delegated to.
+            s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
+            s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
+            obj: object to be serialized and persisted
+        Raises:
+            SerializationError: when fail to serialize object to bytes.
+        """
+        import distributed.protocol as dask
+
+        _upload_bytes_to_s3(
+            _MetaData(SerializationModule.DASK).to_json(),
+            os.path.join(s3_uri, METADATA_FILE),
+            s3_kms_key,
+            sagemaker_session,
+        )
+        try:
+            header, frames = dask.serialize(obj, on_error="raise")
+        except Exception as e:
+            raise SerializationError(
+                "Error when serializing object of type [{}]: {}".format(type(obj).__name__, repr(e))
+            ) from e
+
+        _upload_bytes_to_s3(
+            pickle.dumps(header), s3_path_join(s3_uri, HEADER_FILE), s3_kms_key, sagemaker_session
+        )
+        for idx, frame in enumerate(frames):
+            frame = bytes(frame) if isinstance(frame, memoryview) else frame
+            _upload_bytes_to_s3(
+                frame, s3_path_join(s3_uri, FRAME_FILE.format(idx)), s3_kms_key, sagemaker_session
+            )

    @staticmethod
    def deserialize(sagemaker_session, s3_uri) -> Any:
@@ -110,19 +199,29 @@ def deserialize(sagemaker_session, s3_uri) -> Any:
        Raises:
            DeserializationError: when fail to serialize object to bytes.
        """
-        bytes_to_deserialize = _read_bytes_from_s3(s3_uri, sagemaker_session)
+        import distributed.protocol as dask

+        header_to_deserialize = _read_bytes_from_s3(
+            s3_path_join(s3_uri, HEADER_FILE), sagemaker_session
+        )
+        headers = pickle.loads(header_to_deserialize)
+        num_frames = len(headers["frame-lengths"]) if "frame-lengths" in headers else 1
+        frames = []
+        for idx in range(num_frames):
+            frame = _read_bytes_from_s3(
+                s3_path_join(s3_uri, FRAME_FILE.format(idx)), sagemaker_session
+            )
+            frames.append(frame)
        try:
-            return cloudpickle.loads(bytes_to_deserialize)
+            return dask.deserialize(headers, frames)
        except Exception as e:
            raise DeserializationError(
                "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
            ) from e


-# TODO: use dask serializer in case dask distributed is installed in users' environment.
def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
-    """Serializes function and uploads it to S3.
+    """Serializes function using cloudpickle and uploads it to S3.

    Args:
        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
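The Dask path splits an object into a picklable header plus a list of binary frames; the header is stored as headers.pkl and each frame as frame-{idx}.dat. A minimal local sketch of that round trip, assuming the distributed package is installed (numpy is used here only as an example payload and is not required by the diff):

```python
# Local sketch of the header/frames round trip used by DaskSerializer.
import pickle

import numpy as np
import distributed.protocol as dask

obj = {"weights": np.arange(10.0)}
header, frames = dask.serialize(obj, on_error="raise")

# Mirrors what the serializer uploads: a pickled header plus raw frame bytes,
# with memoryviews converted to bytes first.
stored_header = pickle.dumps(header)
stored_frames = [bytes(f) if isinstance(f, memoryview) else f for f in frames]

restored = dask.deserialize(pickle.loads(stored_header), stored_frames)
assert np.array_equal(restored["weights"], obj["weights"])
```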
@@ -133,13 +232,7 @@ def serialize_func_to_s3(func: Callable, sagemaker_session, s3_uri, s3_kms_key=None):
    Raises:
        SerializationError: when fail to serialize function to bytes.
    """
-
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        func, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(func, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
@@ -157,16 +250,16 @@ def deserialize_func_from_s3(sagemaker_session, s3_uri) -> Callable:
    Raises:
        DeserializationError: when fail to serialize function to bytes.
    """
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
    """Serializes data object and uploads it to S3.

+    This method uses the Dask library to perform serialization if it is already installed;
+    otherwise, it uses cloudpickle.
+
    Args:
        sagemaker_session (sagemaker.session.Session): The underlying Boto3 session which AWS service
            calls are delegated to.
@@ -177,12 +270,12 @@ def serialize_obj_to_s3(obj: Any, sagemaker_session, s3_uri: str, s3_kms_key: str = None):
        SerializationError: when fail to serialize object to bytes.
    """

-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        obj, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    try:
+        import distributed.protocol as dask  # noqa: F401
+
+        DaskSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)
+    except ModuleNotFoundError:
+        CloudpickleSerializer.serialize(obj, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
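Taken together with the deserialize_obj_from_s3 change in the next hunk, the caller-facing behavior is: objects are written with Dask when distributed is importable and with cloudpickle otherwise, and the reader picks the matching deserializer from metadata.json. A hypothetical usage sketch; session stands in for a real sagemaker.session.Session and the URI is made up:

```python
# Hypothetical usage; not runnable as-is because `session` must be a real
# sagemaker.session.Session and the bucket must exist.
data = {"features": [1.0, 2.0, 3.0], "label": 1}
s3_uri = "s3://example-bucket/remote-function/run-1234/results"

serialize_obj_to_s3(data, session, s3_uri)           # Dask if `distributed` is installed, else cloudpickle
restored = deserialize_obj_from_s3(session, s3_uri)  # dispatches on metadata.json
assert restored == data
```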
@@ -197,12 +290,12 @@ def deserialize_obj_from_s3(sagemaker_session, s3_uri) -> Any:
    Raises:
        DeserializationError: when fail to serialize object to bytes.
    """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+    metadata = _MetaData.from_json(
+        _read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session)
    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    if metadata.serialization_module == SerializationModule.DASK:
+        return DaskSerializer.deserialize(sagemaker_session, s3_uri)
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


def serialize_exception_to_s3(
@@ -220,12 +313,7 @@ def serialize_exception_to_s3(
        SerializationError: when fail to serialize object to bytes.
    """
    pickling_support.install()
-    _upload_bytes_to_s3(
-        _MetaData().to_json(), os.path.join(s3_uri, "metadata.json"), s3_kms_key, sagemaker_session
-    )
-    CloudpickleSerializer.serialize(
-        exc, sagemaker_session, os.path.join(s3_uri, "payload.pkl"), s3_kms_key
-    )
+    CloudpickleSerializer.serialize(exc, sagemaker_session, s3_uri, s3_kms_key)


def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
@@ -240,12 +328,8 @@ def deserialize_exception_from_s3(sagemaker_session, s3_uri) -> Any:
    Raises:
        DeserializationError: when fail to serialize object to bytes.
    """
-
-    _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
-    )
-
-    return CloudpickleSerializer.deserialize(sagemaker_session, os.path.join(s3_uri, "payload.pkl"))
+    _MetaData.from_json(_read_bytes_from_s3(os.path.join(s3_uri, METADATA_FILE), sagemaker_session))
+    return CloudpickleSerializer.deserialize(sagemaker_session, s3_uri)


def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):