Skip to content

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def default(self, obj):
8787
return int(obj)
8888
elif isinstance(obj, np.floating):
8989
return float(obj)
90-
elif isinstance(obj, np.ndarray):
90+
elif isinstance(obj, np.ndarray) or hasattr(obj, "tolist"):
9191
return obj.tolist()
9292
elif isinstance(obj, datetime.datetime):
9393
return obj.__str__()

tests/unit/test_decoder_encoder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from mms.service import PredictionException
2020
from PIL import Image
2121
from sagemaker_huggingface_inference_toolkit import decoder_encoder
22+
from transformers.testing_utils import require_torch
2223

2324

2425
ENCODE_JSON_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]}
@@ -83,6 +84,12 @@ def test_encode_json():
8384
encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT)
8485
assert json.loads(encoded_data) == ENCODE_JSON_INPUT
8586

87+
@require_torch
88+
def test_encode_json_torch():
89+
import torch
90+
DATA=[1, 0.5, 5.0]
91+
encoded_data = decoder_encoder.encode_json({"data": torch.tensor(DATA)})
92+
assert json.loads(encoded_data) == {"data": DATA}
8693

8794
def test_encode_csv():
8895
decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)

0 commit comments

Comments
 (0)