Skip to content

Commit c4c9da0

Browse files
committed
Address review comments
1 parent 4edf510 commit c4c9da0

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def default(self, obj):
8787
return int(obj)
8888
elif isinstance(obj, np.floating):
8989
return float(obj)
90-
elif isinstance(obj, np.ndarray) or hasattr(obj, "tolist"):
90+
elif hasattr(obj, "tolist"):
9191
return obj.tolist()
9292
elif isinstance(obj, datetime.datetime):
9393
return obj.__str__()

tests/unit/test_decoder_encoder.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import os
1616

17+
import numpy as np
1718
import pytest
1819
from transformers.testing_utils import require_torch
1920

@@ -27,6 +28,7 @@
2728
{"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33},
2829
{"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0},
2930
]
31+
ENCODE_TOLOIST_INPUT = [1, 0.5, 5.0]
3032

3133
DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
3234
DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
@@ -89,9 +91,13 @@ def test_encode_json():
8991
def test_encode_json_torch():
9092
import torch
9193

92-
DATA = [1, 0.5, 5.0]
93-
encoded_data = decoder_encoder.encode_json({"data": torch.tensor(DATA)})
94-
assert json.loads(encoded_data) == {"data": DATA}
94+
encoded_data = decoder_encoder.encode_json({"data": torch.tensor(ENCODE_TOLOIST_INPUT)})
95+
assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
96+
97+
98+
def test_encode_json_numpy():
99+
encoded_data = decoder_encoder.encode_json({"data": np.array(ENCODE_TOLOIST_INPUT)})
100+
assert json.loads(encoded_data) == {"data": ENCODE_TOLOIST_INPUT}
95101

96102

97103
def test_encode_csv():

0 commit comments

Comments
 (0)