diff --git a/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py b/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
index 25d7765..00eac62 100644
--- a/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
+++ b/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
@@ -87,7 +87,7 @@ def default(self, obj):
             return int(obj)
         elif isinstance(obj, np.floating):
             return float(obj)
-        elif isinstance(obj, np.ndarray):
+        elif hasattr(obj, "tolist"):
             return obj.tolist()
         elif isinstance(obj, datetime.datetime):
             return obj.__str__()
diff --git a/tests/unit/test_decoder_encoder.py b/tests/unit/test_decoder_encoder.py
index 704c49f..a4d0284 100644
--- a/tests/unit/test_decoder_encoder.py
+++ b/tests/unit/test_decoder_encoder.py
@@ -14,7 +14,9 @@
 import json
 import os
 
+import numpy as np
 import pytest
+from transformers.testing_utils import require_torch
 from mms.service import PredictionException
 from PIL import Image
 
@@ -26,6 +28,7 @@
     {"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33},
     {"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0},
 ]
+ENCODE_TOLIST_INPUT = [1, 0.5, 5.0]
 
 DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
 DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
@@ -84,6 +87,19 @@ def test_encode_json():
     assert json.loads(encoded_data) == ENCODE_JSON_INPUT
 
 
+@require_torch
+def test_encode_json_torch():
+    import torch
+
+    encoded_data = decoder_encoder.encode_json({"data": torch.tensor(ENCODE_TOLIST_INPUT)})
+    assert json.loads(encoded_data) == {"data": ENCODE_TOLIST_INPUT}
+
+
+def test_encode_json_numpy():
+    encoded_data = decoder_encoder.encode_json({"data": np.array(ENCODE_TOLIST_INPUT)})
+    assert json.loads(encoded_data) == {"data": ENCODE_TOLIST_INPUT}
+
+
 def test_encode_csv():
     decoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)
     assert (
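
Note (not part of the patch): a minimal sketch of why the duck-typed check is sufficient. Both numpy arrays and torch tensors expose a .tolist() method, so replacing the explicit np.ndarray check with hasattr(obj, "tolist") lets the toolkit's custom JSON encoder (the default method shown in the first hunk) serialize either type without any torch-specific handling. The _ToListEncoder class below is a hypothetical stand-in used only for illustration, not the toolkit's actual encoder.

import json

import numpy as np


class _ToListEncoder(json.JSONEncoder):
    """Illustrative encoder mirroring the patched branch in decoder_encoder.py."""

    def default(self, obj):
        # Duck-typed check: covers np.ndarray, torch.Tensor, and anything else
        # that can turn itself into a plain Python list.
        if hasattr(obj, "tolist"):
            return obj.tolist()
        return super().default(obj)


print(json.dumps({"data": np.array([1, 0.5, 5.0])}, cls=_ToListEncoder))
# -> {"data": [1.0, 0.5, 5.0]}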