|
14 | 14 |
|
15 | 15 | import io
|
16 | 16 |
|
17 |
| -from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer |
| 17 | +import numpy as np |
| 18 | +import pytest |
| 19 | + |
| 20 | +from sagemaker.deserializers import ( |
| 21 | + StringDeserializer, |
| 22 | + BytesDeserializer, |
| 23 | + StreamDeserializer, |
| 24 | + NumpyDeserializer, |
| 25 | +) |
18 | 26 |
|
19 | 27 |
|
20 | 28 | def test_string_deserializer():
|
@@ -44,3 +52,70 @@ def test_stream_deserializer():
|
44 | 52 |
|
45 | 53 | assert result == b"[1, 2, 3]"
|
46 | 54 | assert content_type == "application/json"
|
| 55 | + |
| 56 | + |
| 57 | +@pytest.fixture |
| 58 | +def numpy_deserializer(): |
| 59 | + return NumpyDeserializer() |
| 60 | + |
| 61 | + |
| 62 | +def test_numpy_deserializer_from_csv(numpy_deserializer): |
| 63 | + stream = io.BytesIO(b"1,2,3\n4,5,6") |
| 64 | + array = numpy_deserializer.deserialize(stream, "text/csv") |
| 65 | + assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6]])) |
| 66 | + |
| 67 | + |
| 68 | +def test_numpy_deserializer_from_csv_ragged(numpy_deserializer): |
| 69 | + stream = io.BytesIO(b"1,2,3\n4,5,6,7") |
| 70 | + with pytest.raises(ValueError) as error: |
| 71 | + numpy_deserializer.deserialize(stream, "text/csv") |
| 72 | + assert "errors were detected" in str(error) |
| 73 | + |
| 74 | + |
| 75 | +def test_numpy_deserializer_from_csv_alpha(): |
| 76 | + numpy_deserializer = NumpyDeserializer(dtype="U5") |
| 77 | + stream = io.BytesIO(b"hello,2,3\n4,5,6") |
| 78 | + array = numpy_deserializer.deserialize(stream, "text/csv") |
| 79 | + assert np.array_equal(array, np.array([["hello", 2, 3], [4, 5, 6]])) |
| 80 | + |
| 81 | + |
| 82 | +def test_numpy_deserializer_from_json(numpy_deserializer): |
| 83 | + stream = io.BytesIO(b"[[1,2,3],\n[4,5,6]]") |
| 84 | + array = numpy_deserializer.deserialize(stream, "application/json") |
| 85 | + assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6]])) |
| 86 | + |
| 87 | + |
| 88 | +# Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists) |
| 89 | +def test_numpy_deserializer_from_json_ragged(numpy_deserializer): |
| 90 | + stream = io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]") |
| 91 | + array = numpy_deserializer.deserialize(stream, "application/json") |
| 92 | + assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6, 7]])) |
| 93 | + |
| 94 | + |
| 95 | +def test_numpy_deserializer_from_json_alpha(): |
| 96 | + numpy_deserializer = NumpyDeserializer(dtype="U5") |
| 97 | + stream = io.BytesIO(b'[["hello",2,3],\n[4,5,6]]') |
| 98 | + array = numpy_deserializer.deserialize(stream, "application/json") |
| 99 | + assert np.array_equal(array, np.array([["hello", 2, 3], [4, 5, 6]])) |
| 100 | + |
| 101 | + |
| 102 | +def test_numpy_deserializer_from_npy(numpy_deserializer): |
| 103 | + array = np.ones((2, 3)) |
| 104 | + stream = io.BytesIO() |
| 105 | + np.save(stream, array) |
| 106 | + stream.seek(0) |
| 107 | + |
| 108 | + result = numpy_deserializer.deserialize(stream, "application/x-npy") |
| 109 | + |
| 110 | + assert np.array_equal(array, result) |
| 111 | + |
| 112 | + |
| 113 | +def test_numpy_deserializer_from_npy_object_array(numpy_deserializer): |
| 114 | + array = np.array(["one", "two"]) |
| 115 | + stream = io.BytesIO() |
| 116 | + np.save(stream, array) |
| 117 | + stream.seek(0) |
| 118 | + |
| 119 | + result = numpy_deserializer.deserialize(stream, "application/x-npy") |
| 120 | + |
| 121 | + assert np.array_equal(array, result) |
0 commit comments