Skip to content

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def default(self, obj):
8787
return int(obj)
8888
elif isinstance(obj, np.floating):
8989
return float(obj)
90-
elif isinstance(obj, np.ndarray):
90+
elif isinstance(obj, np.ndarray) or hasattr(obj, "tolist"):
9191
return obj.tolist()
9292
elif isinstance(obj, datetime.datetime):
9393
return obj.__str__()

tests/unit/test_decoder_encoder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mms.service import PredictionException
2020
from PIL import Image
2121
from sagemaker_huggingface_inference_toolkit import decoder_encoder
22+
from transformers.testing_utils import require_torch
2223

2324

2425
ENCODE_JSON_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]}
@@ -84,6 +85,15 @@ def test_encode_json():
8485
assert json.loads(encoded_data) == ENCODE_JSON_INPUT
8586

8687

88+
@require_torch
89+
def test_encode_json_torch():
90+
import torch
91+
92+
DATA = [1, 0.5, 5.0]
93+
encoded_data = decoder_encoder.encode_json({"data": torch.tensor(DATA)})
94+
assert json.loads(encoded_data) == {"data": DATA}
95+
96+
8797
def test_encode_csv():
8898
decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)
8999
assert (

0 commit comments

Comments
 (0)