Skip to content

Commit 06f72ec

Browse files
jncltyangaws
authored andcommitted
Make record serializer accept one-dimensional arrays (#320) (#334)
1 parent 7baf682 commit 06f72ec

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.9.1dev
6+
========
7+
8+
* bug-fix: Estimators: Fix serialization of single records
9+
510
1.9.0
611
=====
712

src/sagemaker/amazon/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, content_type='application/x-recordio-protobuf'):
2929

3030
def __call__(self, array):
3131
if len(array.shape) == 1:
32-
array.reshape(1, array.shape[0])
32+
array = array.reshape(1, array.shape[0])
3333
assert len(array.shape) == 2, "Expecting a 1 or 2 dimensional array"
3434
buf = io.BytesIO()
3535
write_numpy_to_dense_tensor(buf, array)

tests/unit/test_common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ def test_serializer():
3232
assert record.features["values"].float64_tensor.values == expected
3333

3434

35+
def test_serializer_accepts_one_dimensional_array():
36+
s = numpy_to_record_serializer()
37+
array_data = [1.0, 2.0, 3.0]
38+
buf = s(np.array(array_data))
39+
record_data = next(_read_recordio(buf))
40+
record = Record()
41+
record.ParseFromString(record_data)
42+
assert record.features["values"].float64_tensor.values == array_data
43+
44+
3545
def test_deserializer():
3646
array_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
3747
s = numpy_to_record_serializer()

0 commit comments

Comments
 (0)