Skip to content

Commit 2743a73

Browse files
authored
Merge pull request #25 from philschmid/add-error-for-csv-without-header
Add exception for CSV without headers.
2 parents 2fb97f9 + 4a74cf1 commit 2743a73

File tree

4 files changed

+31
-16
lines changed

4 files changed

+31
-16
lines changed

src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,17 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import csv
1415
import datetime
1516
import json
1617
from io import StringIO
17-
import csv
18+
1819
import numpy as np
19-
from sagemaker_inference.decoder import (
20-
_npy_to_numpy,
21-
_npz_to_sparse,
22-
)
23-
from sagemaker_inference.encoder import (
24-
_array_to_npy,
25-
)
26-
from sagemaker_inference import (
27-
content_types,
28-
errors,
29-
)
20+
from sagemaker_inference import content_types, errors
21+
from sagemaker_inference.decoder import _npy_to_numpy, _npz_to_sparse
22+
from sagemaker_inference.encoder import _array_to_npy
23+
24+
from mms.service import PredictionException
3025

3126

3227
def decode_json(content):
@@ -42,6 +37,13 @@ def decode_csv(string_like): # type: (str) -> np.array
4237
(dict): dictionary for input
4338
"""
4439
stream = StringIO(string_like)
40+
# detects if the incoming csv has headers
41+
if not any(header in string_like.splitlines()[0].lower() for header in ["question", "context", "inputs"]):
42+
raise PredictionException(
43+
f"You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
44+
400,
45+
)
46+
# parses the csv stream into a list of row dicts keyed by the header columns
4547
request_list = list(csv.DictReader(stream))
4648
if "inputs" in request_list[0].keys():
4749
return {"inputs": [entry["inputs"] for entry in request_list]}
@@ -123,6 +125,8 @@ def decode(content, content_type=content_types.JSON):
123125
return decoder(content)
124126
except KeyError:
125127
raise errors.UnsupportedFormatError(content_type)
128+
except PredictionException as pred_err:
129+
raise pred_err
126130

127131

128132
def encode(content, content_type=content_types.JSON):

src/sagemaker_huggingface_inference_toolkit/handler_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
import importlib
1616
import logging
1717
import os
18-
import time
1918
import sys
19+
import time
2020
from abc import ABC
2121

22-
from sagemaker_inference import environment, utils, content_types
22+
from sagemaker_inference import content_types, environment, utils
2323
from transformers.pipelines import SUPPORTED_TASKS
2424

25-
from mms.service import PredictionException
2625
from mms import metrics
26+
from mms.service import PredictionException
2727
from sagemaker_huggingface_inference_toolkit import decoder_encoder
2828
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
2929
_is_gpu_available,

tests/unit/test_decoder_encoder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414
import json
1515

16+
import pytest
17+
18+
from mms.service import PredictionException
1619
from sagemaker_huggingface_inference_toolkit import decoder_encoder
1720

1821

@@ -46,6 +49,13 @@ def test_decode_csv():
4649
assert decoded_data == {"inputs": ["I love you", "I like you"]}
4750

4851

52+
def test_decode_csv_without_header():
53+
with pytest.raises(PredictionException):
54+
decoder_encoder.decode_csv(
55+
"where do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
56+
)
57+
58+
4959
def test_encode_json():
5060
encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT)
5161
assert json.loads(encoded_data) == ENCODE_JSON_INPUT

tests/unit/test_handler_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
import tempfile
1818

1919
import pytest
20+
from sagemaker_inference import content_types
2021
from transformers.testing_utils import require_torch, slow
2122

2223
from mms.context import Context, RequestProcessor
2324
from mms.metrics.metrics_store import MetricsStore
2425
from sagemaker_huggingface_inference_toolkit import handler_service
2526
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
26-
from sagemaker_inference import content_types
27+
2728

2829
TASK = "text-classification"
2930
MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"

0 commit comments

Comments
 (0)