
Commit 7a8635b

pintaoz-aws authored and root committed
Add backward compatibility for RecordSerializer and RecordDeserializer (aws#5052)
* Add backward compatibility for RecordSerializer and RecordDeserializer

* fix circular import

* fix test

---------

Co-authored-by: pintaoz <[email protected]>
1 parent 122ea28 commit 7a8635b

File tree

4 files changed (+234, -209 lines changed)


src/sagemaker/amazon/common.py

Lines changed: 10 additions & 207 deletions
@@ -13,210 +13,13 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
-import logging
-import struct
-import sys
-
-import numpy as np
-
-from sagemaker.amazon.record_pb2 import Record
-from sagemaker.utils import DeferredError
-
-
-def _write_feature_tensor(resolved_type, record, vector):
-    """Placeholder Docstring"""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.values.extend(vector)
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.values.extend(vector)
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.values.extend(vector)
-
-
-def _write_label_tensor(resolved_type, record, scalar):
-    """Placeholder Docstring"""
-    if resolved_type == "Int32":
-        record.label["values"].int32_tensor.values.extend([scalar])
-    elif resolved_type == "Float64":
-        record.label["values"].float64_tensor.values.extend([scalar])
-    elif resolved_type == "Float32":
-        record.label["values"].float32_tensor.values.extend([scalar])
-
-
-def _write_keys_tensor(resolved_type, record, vector):
-    """Placeholder Docstring"""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.keys.extend(vector)
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.keys.extend(vector)
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.keys.extend(vector)
-
-
-def _write_shape(resolved_type, record, scalar):
-    """Placeholder Docstring"""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.shape.extend([scalar])
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.shape.extend([scalar])
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.shape.extend([scalar])
-
-
-def write_numpy_to_dense_tensor(file, array, labels=None):
-    """Writes a numpy array to a dense tensor
-
-    Args:
-        file:
-        array:
-        labels:
-    """
-
-    # Validate shape of array and labels, resolve array and label types
-    if not len(array.shape) == 2:
-        raise ValueError("Array must be a Matrix")
-    if labels is not None:
-        if not len(labels.shape) == 1:
-            raise ValueError("Labels must be a Vector")
-        if labels.shape[0] not in array.shape:
-            raise ValueError(
-                "Label shape {} not compatible with array shape {}".format(
-                    labels.shape, array.shape
-                )
-            )
-        resolved_label_type = _resolve_type(labels.dtype)
-    resolved_type = _resolve_type(array.dtype)
-
-    # Write each vector in array into a Record in the file object
-    record = Record()
-    for index, vector in enumerate(array):
-        record.Clear()
-        _write_feature_tensor(resolved_type, record, vector)
-        if labels is not None:
-            _write_label_tensor(resolved_label_type, record, labels[index])
-        _write_recordio(file, record.SerializeToString())
-
-
-def write_spmatrix_to_sparse_tensor(file, array, labels=None):
-    """Writes a scipy sparse matrix to a sparse tensor
-
-    Args:
-        file:
-        array:
-        labels:
-    """
-    try:
-        import scipy
-    except ImportError as e:
-        logging.warning(
-            "scipy failed to import. Sparse matrix functions will be impaired or broken."
-        )
-        # Any subsequent attempt to use scipy will raise the ImportError
-        scipy = DeferredError(e)
-
-    if not scipy.sparse.issparse(array):
-        raise TypeError("Array must be sparse")
-
-    # Validate shape of array and labels, resolve array and label types
-    if not len(array.shape) == 2:
-        raise ValueError("Array must be a Matrix")
-    if labels is not None:
-        if not len(labels.shape) == 1:
-            raise ValueError("Labels must be a Vector")
-        if labels.shape[0] not in array.shape:
-            raise ValueError(
-                "Label shape {} not compatible with array shape {}".format(
-                    labels.shape, array.shape
-                )
-            )
-        resolved_label_type = _resolve_type(labels.dtype)
-    resolved_type = _resolve_type(array.dtype)
-
-    csr_array = array.tocsr()
-    n_rows, n_cols = csr_array.shape
-
-    record = Record()
-    for row_idx in range(n_rows):
-        record.Clear()
-        row = csr_array.getrow(row_idx)
-        # Write values
-        _write_feature_tensor(resolved_type, record, row.data)
-        # Write keys
-        _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64))
-
-        # Write labels
-        if labels is not None:
-            _write_label_tensor(resolved_label_type, record, labels[row_idx])
-
-        # Write shape
-        _write_shape(resolved_type, record, n_cols)
-
-        _write_recordio(file, record.SerializeToString())
-
-
-def read_records(file):
-    """Eagerly read a collection of amazon Record protobuf objects from file.
-
-    Args:
-        file:
-    """
-    records = []
-    for record_data in read_recordio(file):
-        record = Record()
-        record.ParseFromString(record_data)
-        records.append(record)
-    return records
-
-
-# MXNet requires recordio records have length in bytes that's a multiple of 4
-# This sets up padding bytes to append to the end of the record, for diferent
-# amounts of padding required.
-padding = {}
-for amount in range(4):
-    if sys.version_info >= (3,):
-        padding[amount] = bytes([0x00 for _ in range(amount)])
-    else:
-        padding[amount] = bytearray([0x00 for _ in range(amount)])
-
-_kmagic = 0xCED7230A
-
-
-def _write_recordio(f, data):
-    """Writes a single data point as a RecordIO record to the given file.
-
-    Args:
-        f:
-        data:
-    """
-    length = len(data)
-    f.write(struct.pack("I", _kmagic))
-    f.write(struct.pack("I", length))
-    pad = (((length + 3) >> 2) << 2) - length
-    f.write(data)
-    f.write(padding[pad])
-
-
-def read_recordio(f):
-    """Placeholder Docstring"""
-    while True:
-        try:
-            (read_kmagic,) = struct.unpack("I", f.read(4))
-        except struct.error:
-            return
-        assert read_kmagic == _kmagic
-        (len_record,) = struct.unpack("I", f.read(4))
-        pad = (((len_record + 3) >> 2) << 2) - len_record
-        yield f.read(len_record)
-        if pad:
-            f.read(pad)
-
-
-def _resolve_type(dtype):
-    """Placeholder Docstring"""
-    if dtype == np.dtype(int):
-        return "Int32"
-    if dtype == np.dtype(float):
-        return "Float64"
-    if dtype == np.dtype("float32"):
-        return "Float32"
-    raise ValueError("Unsupported dtype {} on array".format(dtype))
+# these imports ensure backward compatibility.
+from sagemaker.deserializers import RecordDeserializer  # noqa: F401 # pylint: disable=W0611
+from sagemaker.serializers import RecordSerializer  # noqa: F401 # pylint: disable=W0611
+from sagemaker.serializer_utils import (  # noqa: F401 # pylint: disable=W0611
+    read_recordio,
+    read_records,
+    write_numpy_to_dense_tensor,
+    write_spmatrix_to_sparse_tensor,
+    _write_recordio,
+)
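
Since sagemaker/amazon/common.py is now only a compatibility shim, the guarantee this hunk provides is that the old import path keeps resolving to the same objects as the new modules. A minimal sanity check, as a sketch assuming a sagemaker build that includes this commit:

# Sketch: old and new import paths should resolve to the same objects
# (assumes a sagemaker install that includes this commit).
from sagemaker.amazon.common import RecordDeserializer, RecordSerializer, read_records

from sagemaker.deserializers import RecordDeserializer as NewRecordDeserializer
from sagemaker.serializers import RecordSerializer as NewRecordSerializer
from sagemaker.serializer_utils import read_records as new_read_records

# The shim re-exports rather than copies, so identity (not just equality) holds.
assert RecordSerializer is NewRecordSerializer
assert RecordDeserializer is NewRecordDeserializer
assert read_records is new_read_records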

src/sagemaker/base_deserializers.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 import numpy as np
 from six import with_metaclass
 
-from sagemaker.amazon.common import read_records
+from sagemaker.serializer_utils import read_records
 from sagemaker.utils import DeferredError
 
 try:

src/sagemaker/base_serializers.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 from pandas import DataFrame
 from six import with_metaclass
 
-from sagemaker.amazon.common import write_numpy_to_dense_tensor
+from sagemaker.serializer_utils import write_numpy_to_dense_tensor
 from sagemaker.utils import DeferredError
 
 try:
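
With both base modules now importing from sagemaker.serializer_utils, the relocated helpers can be exercised end to end. A short round-trip sketch, assuming numpy is installed and the function signatures shown in the common.py diff above:

# Sketch: write a small dense matrix as RecordIO-protobuf and read it back
# (assumes numpy and a sagemaker build that includes this commit).
import io

import numpy as np

from sagemaker.serializer_utils import read_records, write_numpy_to_dense_tensor

array = np.array([[1.0, 2.0], [3.0, 4.0]])  # 2 rows -> 2 records
labels = np.array([0.0, 1.0])               # one label per row

buf = io.BytesIO()
write_numpy_to_dense_tensor(buf, array, labels)
buf.seek(0)

records = read_records(buf)
assert len(records) == 2  # one Record protobuf per matrix row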
