 """Placeholder docstring"""
 from __future__ import absolute_import
 
-import logging
-import struct
-import sys
-
-import numpy as np
-
-from sagemaker.amazon.record_pb2 import Record
-from sagemaker.utils import DeferredError
-
-
-def _write_feature_tensor(resolved_type, record, vector):
-    """Write a feature vector to the tensor field matching the resolved dtype."""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.values.extend(vector)
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.values.extend(vector)
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.values.extend(vector)
-
-
-def _write_label_tensor(resolved_type, record, scalar):
-    """Write a single label scalar to the tensor field matching the resolved dtype."""
-    if resolved_type == "Int32":
-        record.label["values"].int32_tensor.values.extend([scalar])
-    elif resolved_type == "Float64":
-        record.label["values"].float64_tensor.values.extend([scalar])
-    elif resolved_type == "Float32":
-        record.label["values"].float32_tensor.values.extend([scalar])
-
-
-def _write_keys_tensor(resolved_type, record, vector):
-    """Write sparse keys (column indices) to the tensor field matching the resolved dtype."""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.keys.extend(vector)
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.keys.extend(vector)
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.keys.extend(vector)
-
-
-def _write_shape(resolved_type, record, scalar):
-    """Write the row width (number of columns) to the tensor field matching the resolved dtype."""
-    if resolved_type == "Int32":
-        record.features["values"].int32_tensor.shape.extend([scalar])
-    elif resolved_type == "Float64":
-        record.features["values"].float64_tensor.shape.extend([scalar])
-    elif resolved_type == "Float32":
-        record.features["values"].float32_tensor.shape.extend([scalar])
-
-
-def write_numpy_to_dense_tensor(file, array, labels=None):
-    """Writes a numpy array to a dense tensor in RecordIO-protobuf format.
-
-    Args:
-        file: binary file-like object to write RecordIO-wrapped records to
-        array: two-dimensional numpy array of feature rows
-        labels: optional one-dimensional numpy array with one label per row
-    """
-
-    # Validate shape of array and labels, resolve array and label types
-    if not len(array.shape) == 2:
-        raise ValueError("Array must be a Matrix")
-    if labels is not None:
-        if not len(labels.shape) == 1:
-            raise ValueError("Labels must be a Vector")
-        if labels.shape[0] not in array.shape:
-            raise ValueError(
-                "Label shape {} not compatible with array shape {}".format(
-                    labels.shape, array.shape
-                )
-            )
-        resolved_label_type = _resolve_type(labels.dtype)
-    resolved_type = _resolve_type(array.dtype)
-
-    # Write each vector in array into a Record in the file object
-    record = Record()
-    for index, vector in enumerate(array):
-        record.Clear()
-        _write_feature_tensor(resolved_type, record, vector)
-        if labels is not None:
-            _write_label_tensor(resolved_label_type, record, labels[index])
-        _write_recordio(file, record.SerializeToString())
-
-
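For example, a caller might serialize a small matrix to an in-memory buffer (a minimal sketch; any binary file-like object works):

    import io
    import numpy as np

    array = np.array([[1.0, 2.0], [3.0, 4.0]])  # 2x2 Float64 feature matrix
    labels = np.array([0.0, 1.0])               # one label per row
    buf = io.BytesIO()
    write_numpy_to_dense_tensor(buf, array, labels)
    buf.seek(0)  # rewind before reading or uploading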
-def write_spmatrix_to_sparse_tensor(file, array, labels=None):
-    """Writes a scipy sparse matrix to a sparse tensor in RecordIO-protobuf format.
-
-    Args:
-        file: binary file-like object to write RecordIO-wrapped records to
-        array: two-dimensional scipy sparse matrix of feature rows
-        labels: optional one-dimensional numpy array with one label per row
-    """
-    try:
-        import scipy.sparse
-    except ImportError as e:
-        logging.warning(
-            "scipy failed to import. Sparse matrix functions will be impaired or broken."
-        )
-        # Any subsequent attempt to use scipy will raise the ImportError
-        scipy = DeferredError(e)
-
-    if not scipy.sparse.issparse(array):
-        raise TypeError("Array must be sparse")
-
-    # Validate shape of array and labels, resolve array and label types
-    if not len(array.shape) == 2:
-        raise ValueError("Array must be a Matrix")
-    if labels is not None:
-        if not len(labels.shape) == 1:
-            raise ValueError("Labels must be a Vector")
-        if labels.shape[0] not in array.shape:
-            raise ValueError(
-                "Label shape {} not compatible with array shape {}".format(
-                    labels.shape, array.shape
-                )
-            )
-        resolved_label_type = _resolve_type(labels.dtype)
-    resolved_type = _resolve_type(array.dtype)
-
-    csr_array = array.tocsr()
-    n_rows, n_cols = csr_array.shape
-
-    record = Record()
-    for row_idx in range(n_rows):
-        record.Clear()
-        row = csr_array.getrow(row_idx)
-        # Write values
-        _write_feature_tensor(resolved_type, record, row.data)
-        # Write keys
-        _write_keys_tensor(resolved_type, record, row.indices.astype(np.uint64))
-
-        # Write labels
-        if labels is not None:
-            _write_label_tensor(resolved_label_type, record, labels[row_idx])
-
-        # Write shape
-        _write_shape(resolved_type, record, n_cols)
-
-        _write_recordio(file, record.SerializeToString())
-
-
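The sparse path follows the same pattern; a minimal sketch with a tiny CSR matrix (only nonzero values and their column keys are stored, plus the row width as shape):

    import io
    import numpy as np
    import scipy.sparse

    spmatrix = scipy.sparse.csr_matrix(np.array([[0.0, 5.0], [7.0, 0.0]]))
    buf = io.BytesIO()
    write_spmatrix_to_sparse_tensor(buf, spmatrix)
    buf.seek(0)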
-def read_records(file):
-    """Eagerly read a collection of amazon Record protobuf objects from file.
-
-    Args:
-        file: binary file-like object positioned at the start of the records
-    """
-    records = []
-    for record_data in read_recordio(file):
-        record = Record()
-        record.ParseFromString(record_data)
-        records.append(record)
-    return records
-
-
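Continuing the dense-tensor sketch above, the same records round-trip back out of the buffer:

    buf.seek(0)
    for rec in read_records(buf):
        print(rec.features["values"].float64_tensor.values)  # one row of features per record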
-# MXNet requires recordio records have length in bytes that's a multiple of 4
-# This sets up padding bytes to append to the end of the record, for different
-# amounts of padding required.
-padding = {}
-for amount in range(4):
-    if sys.version_info >= (3,):
-        padding[amount] = bytes([0x00 for _ in range(amount)])
-    else:
-        padding[amount] = bytearray([0x00 for _ in range(amount)])
-
-_kmagic = 0xCED7230A
-
-
-def _write_recordio(f, data):
-    """Writes a single data point as a RecordIO record to the given file.
-
-    Args:
-        f: binary file-like object to write to
-        data: serialized record payload as bytes
-    """
-    length = len(data)
-    f.write(struct.pack("I", _kmagic))
-    f.write(struct.pack("I", length))
-    pad = (((length + 3) >> 2) << 2) - length
-    f.write(data)
-    f.write(padding[pad])
-
-
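Each record on the wire is a 4-byte magic number, a 4-byte native-endian length, the payload, then zero padding up to the next 4-byte boundary. A quick check of the padding arithmetic:

    for length in (5, 8):
        pad = (((length + 3) >> 2) << 2) - length  # round up to a multiple of 4, take the difference
        print(length, pad)  # 5 -> 3, 8 -> 0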
-def read_recordio(f):
-    """Lazily read RecordIO records from file, yielding one raw payload at a time.
-
-    Args:
-        f: binary file-like object to read from
-    """
-    while True:
-        try:
-            (read_kmagic,) = struct.unpack("I", f.read(4))
-        except struct.error:
-            return
-        assert read_kmagic == _kmagic
-        (len_record,) = struct.unpack("I", f.read(4))
-        pad = (((len_record + 3) >> 2) << 2) - len_record
-        yield f.read(len_record)
-        if pad:
-            f.read(pad)
-
-
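Because read_recordio is a generator, a large file can be decoded one record at a time rather than loaded whole as read_records does; reusing the buffer from the sketches above:

    buf.seek(0)
    for payload in read_recordio(buf):  # yields the raw serialized bytes of each record
        record = Record()
        record.ParseFromString(payload)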
-def _resolve_type(dtype):
-    """Map a numpy dtype to the corresponding protobuf tensor type name."""
-    if dtype == np.dtype(int):
-        return "Int32"
-    if dtype == np.dtype(float):
-        return "Float64"
-    if dtype == np.dtype("float32"):
-        return "Float32"
-    raise ValueError("Unsupported dtype {} on array".format(dtype))
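Only the platform default int, the default float, and float32 are supported; everything else is rejected:

    _resolve_type(np.dtype(float))      # -> "Float64"
    _resolve_type(np.dtype("float32"))  # -> "Float32"
    _resolve_type(np.dtype("uint8"))    # raises ValueError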
+# these imports ensure backward compatibility.
+from sagemaker.deserializers import RecordDeserializer  # noqa: F401 # pylint: disable=W0611
+from sagemaker.serializers import RecordSerializer  # noqa: F401 # pylint: disable=W0611
+from sagemaker.serializer_utils import (  # noqa: F401 # pylint: disable=W0611
+    write_numpy_to_dense_tensor,
+    write_spmatrix_to_sparse_tensor,
+)
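With these re-exports in place, code written against the old module path keeps working unchanged (assuming this module is sagemaker.amazon.common, as the surrounding imports suggest):

    # Both resolve to the same function after the refactor:
    from sagemaker.amazon.common import write_numpy_to_dense_tensor
    from sagemaker.serializer_utils import write_numpy_to_dense_tensor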