Skip to content

Commit e117c76

Browse files
authored
feature: Add sparse matrix serializer (#1739)
1 parent 4746868 commit e117c76

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/sagemaker/serializers.py

+29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020

2121
import numpy as np
2222

23+
from sagemaker.utils import DeferredError
24+
25+
try:
26+
import scipy
27+
except ImportError as e:
28+
scipy = DeferredError(e)
29+
2330

2431
class BaseSerializer(abc.ABC):
2532
"""Abstract base class for creation of new serializers.
@@ -183,3 +190,25 @@ def serialize(self, data):
183190
return json.dumps(data.tolist())
184191

185192
return json.dumps(data)
193+
194+
195+
class SparseMatrixSerializer(BaseSerializer):
196+
"""Serialize a sparse matrix to a buffer using the .npz format."""
197+
198+
CONTENT_TYPE = "application/x-npz"
199+
200+
def serialize(self, data):
201+
"""Serialize a sparse matrix to a buffer using the .npz format.
202+
203+
Sparse matrices can be in the ``csc``, ``csr``, ``bsr``, ``dia`` or
204+
``coo`` formats.
205+
206+
Args:
207+
data (scipy.sparse.spmatrix): The sparse matrix to serialize.
208+
209+
Returns:
210+
io.BytesIO: A buffer containing the serialized sparse matrix.
211+
"""
212+
buffer = io.BytesIO()
213+
scipy.sparse.save_npz(buffer, data)
214+
return buffer.getvalue()

tests/unit/sagemaker/test_serializers.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@
1818

1919
import numpy as np
2020
import pytest
21-
22-
from sagemaker.serializers import CSVSerializer, NumpySerializer, JSONSerializer
21+
import scipy
22+
23+
from sagemaker.serializers import (
24+
CSVSerializer,
25+
NumpySerializer,
26+
JSONSerializer,
27+
SparseMatrixSerializer,
28+
)
2329
from tests.unit import DATA_DIR
2430

2531

@@ -227,3 +233,16 @@ def test_json_serializer_csv_buffer(json_serializer):
227233
csv_file.seek(0)
228234
result = json_serializer.serialize(csv_file)
229235
assert result == validation_value
236+
237+
238+
@pytest.fixture
239+
def sparse_matrix_serializer():
240+
return SparseMatrixSerializer()
241+
242+
243+
def test_sparse_matrix_serializer(sparse_matrix_serializer):
244+
data = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
245+
stream = io.BytesIO(sparse_matrix_serializer.serialize(data))
246+
result = scipy.sparse.load_npz(stream).toarray()
247+
expected = data.toarray()
248+
assert np.array_equal(result, expected)

0 commit comments

Comments
 (0)