Skip to content

Commit 02ba035

Browse files
authored
Add support for JSON encoding torch.tensor to keep it consistent with sagemaker-inference-toolkit (#84)
* Add support for JSON encoding `torch.tensor` to keep it consistent with sagemaker-inference-toolkit Sagemaker base toolkit: https://github.com/aws/sagemaker-inference-toolkit/blob/e602335fd9a4db08216d1f58ded2861cccb64f7d/src/sagemaker_inference/encoder.py#L25_L44 HF inference toolkit: https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/27275f40a2bbff85bb507646e6a3ef866d0599af/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L80_L114 * Address review comments
1 parent 27275f4 commit 02ba035

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def default(self, obj):
8787
return int(obj)
8888
elif isinstance(obj, np.floating):
8989
return float(obj)
90-
elif isinstance(obj, np.ndarray):
90+
elif hasattr(obj, "tolist"):
9191
return obj.tolist()
9292
elif isinstance(obj, datetime.datetime):
9393
return obj.__str__()

tests/unit/test_decoder_encoder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
import json
1515
import os
1616

17+
import numpy as np
1718
import pytest
19+
from transformers.testing_utils import require_torch
1820

1921
from mms.service import PredictionException
2022
from PIL import Image
@@ -26,6 +28,7 @@
2628
{"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33},
2729
{"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0},
2830
]
31+
ENCODE_TOLOIST_INPUT = [1, 0.5, 5.0]
2932

3033
DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
3134
DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
@@ -84,6 +87,19 @@ def test_encode_json():
8487
assert json.loads(encoded_data) == ENCODE_JSON_INPUT
8588

8689

90+
@require_torch
91+
def test_encode_json_torch():
92+
import torch
93+
94+
encoded_data = decoder_encoder.encode_json({"data": torch.tensor(ENCODE_TOLOIST_INPUT)})
95+
assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
96+
97+
98+
def test_encode_json_numpy():
99+
encoded_data = decoder_encoder.encode_json({"data": np.array(ENCODE_TOLOIST_INPUT)})
100+
assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
101+
102+
87103
def test_encode_csv():
88104
decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)
89105
assert (

0 commit comments

Comments
 (0)