Add csv support #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (1 commit) on Jul 21, 2021
44 changes: 39 additions & 5 deletions src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
@@ -13,16 +13,15 @@
# limitations under the License.
import datetime
import json

from io import StringIO
import csv
import numpy as np
from sagemaker_inference.decoder import (
_npy_to_numpy,
_csv_to_numpy,
_npz_to_sparse,
)
from sagemaker_inference.encoder import (
_array_to_npy,
_array_to_csv,
)
from sagemaker_inference import (
content_types,
@@ -34,6 +33,22 @@ def decode_json(content):
return json.loads(content)


def decode_csv(string_like):  # type: (str) -> dict
    """Convert a CSV string to a dictionary with list attributes.

    Args:
        string_like (str): CSV string.
    Returns:
        (dict): dictionary usable as pipeline input
    """
    stream = StringIO(string_like)
    request_list = list(csv.DictReader(stream))
    if "inputs" in request_list[0].keys():
        return {"inputs": [entry["inputs"] for entry in request_list]}
    else:
        return {"inputs": request_list}


# https://github.com/automl/SMAC3/issues/453
class _JSONEncoder(json.JSONEncoder):
"""
@@ -67,14 +82,33 @@ def encode_json(content):
)


def encode_csv(content):  # type: (dict | list) -> str
    """Convert the result of a transformers pipeline to CSV.
    Args:
        content (dict | list): result of transformers pipeline.
    Returns:
        (str): object serialized to CSV
    """
    stream = StringIO()
    if not isinstance(content, list):
        # wrap a single prediction dict in a list so DictWriter can treat it as one row
        content = [content]

    column_header = content[0].keys()
    writer = csv.DictWriter(stream, column_header)

    writer.writeheader()
    writer.writerows(content)
    return stream.getvalue()


_encoder_map = {
content_types.NPY: _array_to_npy,
content_types.CSV: _array_to_csv,
content_types.CSV: encode_csv,
content_types.JSON: encode_json,
}
_decoder_map = {
content_types.NPY: _npy_to_numpy,
content_types.CSV: _csv_to_numpy,
content_types.CSV: decode_csv,
content_types.NPZ: _npz_to_sparse,
content_types.JSON: decode_json,
}
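For reference, a quick sketch of how the two new helpers behave when called directly. The payload strings and prediction dicts below are invented for illustration; only the helper functions themselves come from this diff.

from sagemaker_huggingface_inference_toolkit import decoder_encoder

# a CSV body with an "inputs" column decodes to {"inputs": [<one string per row>]}
print(decoder_encoder.decode_csv("inputs\r\nI love you\r\nI like you"))
# -> {'inputs': ['I love you', 'I like you']}

# any other header decodes to {"inputs": [<one dict per row>]}, e.g. for question-answering
print(decoder_encoder.decode_csv("question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg"))

# a list of prediction dicts is serialized back to CSV with a header row
print(decoder_encoder.encode_csv([{"label": "POSITIVE", "score": 0.99}, {"label": "NEGATIVE", "score": 0.01}]))
# -> label,score\r\nPOSITIVE,0.99\r\nNEGATIVE,0.01\r\n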
17 changes: 14 additions & 3 deletions src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -124,6 +124,15 @@ def preprocess(self, input_data, content_type):
Returns:
            decoded_input_data (dict): deserialized input_data into a Python dictionary.
"""
        # raise an error for zero-shot-classification and table-question-answering: their nested inputs cannot be represented as CSV
        if (
            os.environ["HF_TASK"] == "zero-shot-classification" or os.environ["HF_TASK"] == "table-question-answering"
        ) and content_type == content_types.CSV:
            raise PredictionException(
                f"content type {content_type} not supported with {os.environ['HF_TASK']}, use a different content_type",
                400,
            )

decoded_input_data = decoder_encoder.decode(input_data, content_type)
return decoded_input_data

@@ -182,9 +191,11 @@ def transform_fn(self, model, input_data, content_type, accept):
predict_time = time.time() - preprocess_time
response = self.postprocess(predictions, accept)

logger.info(f"Preprocess time - {preprocess_time * 1000} ms\n"
f"Predict time - {predict_time * 1000} ms\n"
f"Postprocess time - {(time.time() - predict_time) * 1000} ms")
logger.info(
f"Preprocess time - {preprocess_time * 1000} ms\n"
f"Predict time - {predict_time * 1000} ms\n"
f"Postprocess time - {(time.time() - predict_time) * 1000} ms"
)

return response
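For context, a minimal client-side sketch of calling an endpoint built with this toolkit using the new content type. The endpoint name is a placeholder; invoke_endpoint is the standard SageMaker runtime API.

import boto3

runtime = boto3.client("sagemaker-runtime")

response = runtime.invoke_endpoint(
    EndpointName="my-huggingface-endpoint",  # placeholder name
    ContentType="text/csv",
    Accept="text/csv",
    Body="inputs\r\nI love you\r\nI like you",
)
print(response["Body"].read().decode("utf-8"))
# for zero-shot-classification or table-question-answering the same request is
# rejected with the 400 PredictionException added above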

@@ -201,6 +201,8 @@ def infer_task_from_model_architecture(model_config_path: str, architecture_inde
f"Inference Toolkit can only inference tasks from architectures ending with {list(ARCHITECTURES_2_TASK.keys())}."
"Use env `HF_TASK` to define your task."
)
    # expose the task via environment variable so other components (e.g. the CSV check in preprocess) can read it
os.environ["HF_TASK"] = task
return task


@@ -211,6 +213,8 @@ def infer_task_from_hub(model_id: str, revision: Optional[str] = None, use_auth_
_api = HfApi()
model_info = _api.model_info(repo_id=model_id, revision=revision, token=use_auth_token)
if model_info.pipeline_tag is not None:
        # expose the task via environment variable so other components (e.g. the CSV check in preprocess) can read it
os.environ["HF_TASK"] = model_info.pipeline_tag
return model_info.pipeline_tag
else:
raise ValueError(
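The exported HF_TASK value is what the CSV check in preprocess reads. A standalone sketch of the same hub lookup using huggingface_hub directly; the model id is only an example.

import os

from huggingface_hub import HfApi

model_info = HfApi().model_info("distilbert-base-uncased-finetuned-sst-2-english")  # example model id
if model_info.pipeline_tag is not None:
    # same side effect as infer_task_from_hub above
    os.environ["HF_TASK"] = model_info.pipeline_tag
print(os.environ.get("HF_TASK"))  # e.g. "text-classification"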
47 changes: 37 additions & 10 deletions tests/unit/test_decoder_encoder.py
@@ -16,20 +16,47 @@
from sagemaker_huggingface_inference_toolkit import decoder_encoder


ENCODE_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]}
DECODE_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
ENCODE_JSON_INPUT = {"upper": [1425], "lower": [576], "level": [2], "datetime": ["2012-08-08 15:30"]}
ENCODE_CSV_INPUT = [
{"answer": "Nuremberg", "end": 42, "score": 0.9926825761795044, "start": 33},
{"answer": "Berlin is the capital of Germany", "end": 32, "score": 0.26097726821899414, "start": 0},
]

DECODE_JSON_INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
DECODE_CSV_INPUT = "question,context\r\nwhere do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"

CONTENT_TYPE = "application/json"


def test_decode_json():
decoded_data = decoder_encoder.decode_json(json.dumps(DECODE_INPUT))
assert decoded_data == DECODE_INPUT
decoded_data = decoder_encoder.decode_json(json.dumps(DECODE_JSON_INPUT))
assert decoded_data == DECODE_JSON_INPUT


def test_decode_csv():
decoded_data = decoder_encoder.decode_csv(DECODE_CSV_INPUT)
assert decoded_data == {
"inputs": [
{"question": "where do i live?", "context": "My name is Philipp and I live in Nuremberg"},
{"question": "where is Berlin?", "context": "Berlin is the capital of Germany"},
]
}
text_classification_input = "inputs\r\nI love you\r\nI like you"
    decoded_data = decoder_encoder.decode_csv(text_classification_input)
assert decoded_data == {"inputs": ["I love you", "I like you"]}


def test_encode_json():
encoded_data = decoder_encoder.encode_json(ENCODE_INPUT)
assert json.loads(encoded_data) == ENCODE_INPUT
encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT)
assert json.loads(encoded_data) == ENCODE_JSON_INPUT


def test_encode_csv():
    encoded_data = decoder_encoder.encode_csv(ENCODE_CSV_INPUT)
    assert (
        encoded_data
        == "answer,end,score,start\r\nNuremberg,42,0.9926825761795044,33\r\nBerlin is the capital of Germany,32,0.26097726821899414,0\r\n"
    )


def test_decode_content_type():
@@ -43,10 +70,10 @@ def test_encode_content_type():


def test_decode():
decode = decoder_encoder.decode(json.dumps(DECODE_INPUT), CONTENT_TYPE)
assert decode == DECODE_INPUT
decode = decoder_encoder.decode(json.dumps(DECODE_JSON_INPUT), CONTENT_TYPE)
assert decode == DECODE_JSON_INPUT


def test_encode():
encode = decoder_encoder.encode(ENCODE_INPUT, CONTENT_TYPE)
assert json.loads(encode) == ENCODE_INPUT
encode = decoder_encoder.encode(ENCODE_JSON_INPUT, CONTENT_TYPE)
assert json.loads(encode) == ENCODE_JSON_INPUT
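The new mappings are also reachable through the generic decode/encode entry points exercised in test_decode and test_encode; a small sketch with the CSV content type (content_types.CSV is "text/csv"):

from sagemaker_inference import content_types

from sagemaker_huggingface_inference_toolkit import decoder_encoder

decoded = decoder_encoder.decode("inputs\r\nI love you\r\nI like you", content_types.CSV)
assert decoded == {"inputs": ["I love you", "I like you"]}

encoded = decoder_encoder.encode([{"label": "POSITIVE", "score": 0.99}], content_types.CSV)
assert encoded == "label,score\r\nPOSITIVE,0.99\r\n"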