Skip to content

Commit cfe8f9f

Browse files
author
Michael Trinh
committed
feature: Add support for JSON model inputs for Clarify Processor
1 parent f25beb5 commit cfe8f9f

File tree

2 files changed

+133
-19
lines changed

2 files changed

+133
-19
lines changed

src/sagemaker/clarify.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@
282282
in (
283283
"text/csv",
284284
"application/jsonlines",
285+
"application/json",
285286
"image/jpeg",
286287
"image/png",
287288
"application/x-npy",
@@ -296,6 +297,7 @@
296297
SchemaOptional("probability"): Or(str, int),
297298
SchemaOptional("label_headers"): [Or(str, int)],
298299
SchemaOptional("content_template"): Or(str, {str: str}),
300+
SchemaOptional("record_template"): str,
299301
SchemaOptional("custom_attributes"): str,
300302
},
301303
}
@@ -573,6 +575,7 @@ def __init__(
573575
accept_type: Optional[str] = None,
574576
content_type: Optional[str] = None,
575577
content_template: Optional[str] = None,
578+
record_template: Optional[str] = None,
576579
custom_attributes: Optional[str] = None,
577580
accelerator_type: Optional[str] = None,
578581
endpoint_name_prefix: Optional[str] = None,
@@ -599,14 +602,80 @@ def __init__(
599602
``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
600603
Default is the same as ``content_type``.
601604
content_type (str): The model input format to be used for getting inferences with the
602-
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
603-
``"application/jsonlines"`` for JSON Lines. Default is the same as
604-
``dataset_format``.
605+
shadow endpoint. Valid values are ``"text/csv"`` for CSV,
606+
``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
607+
Default is the same as ``dataset_format``.
605608
content_template (str): A template string to be used to construct the model input from
606-
dataset instances. It is only used when ``model_content_type`` is
607-
``"application/jsonlines"``. The template should have one and only one placeholder,
608-
``"features"``, which will be replaced by a features list to form the model
609-
inference input.
609+
dataset instances. It is only used, and required, when ``model_content_type`` is
610+
``"application/jsonlines"`` or ``"application/json"``. When ``model_content_type``
611+
is ``application/jsonlines``, the template should have one and only one
612+
placeholder, ``$features``, which will be replaced by a features list for each
613+
record to form the model inference input. When ``model_content_type`` is
614+
``application/json``, the template can have either placeholder ``$record``, which
615+
will be replaced by a single record templated by ``record_template`` and only a
616+
single record at a time will be sent to the model, or placeholder ``$records``,
617+
which will be replaced by a list of records, each templated by ``record_template``.
618+
record_template (str): A template string to be used to construct each record of the
619+
model input from dataset instances. It is only used, and required, when
620+
``model_content_type`` is ``"application/json"``.
621+
The template string may contain one of the following:
622+
623+
* Placeholder ``$features`` that will be substituted by the array of feature values
624+
and/or an optional placeholder ``$feature_names`` that will be substituted by the
625+
array of feature names.
626+
* Exactly one placeholder ``$features_kvp`` that will be substituted by the
627+
key-value pairs of feature name and feature value.
628+
* Or for each feature, if "A" is the feature name in the ``headers`` configuration,
629+
then placeholder syntax ``"${A}"`` (the double-quotes are part of the
630+
placeholder) will be substituted by the feature value.
631+
632+
``record_template`` will be used in conjunction with ``content_template`` to
633+
construct the model input.
634+
635+
**Examples:**
636+
637+
Given:
638+
639+
* ``headers``: ``["A", "B"]``
640+
* ``features``: ``[[0, 1], [3, 4]]``
641+
642+
Example model input 1::
643+
644+
{
645+
"instances": [[0, 1], [3, 4]],
646+
"feature_names": ["A", "B"]
647+
}
648+
649+
content_template and record_template to construct above:
650+
651+
* ``content_template``: ``"{\"instances\": $records}"``
652+
* ``record_template``: ``"$features"``
653+
654+
Example model input 2::
655+
656+
[
657+
{ "A": 0, "B": 1 },
658+
{ "A": 3, "B": 4 }
659+
]
660+
661+
content_template and record_template to construct above:
662+
663+
* ``content_template``: ``"$records"``
664+
* ``record_template``: ``"$features_kvp"``
665+
666+
Or, alternatively:
667+
668+
* ``content_template``: ``"$records"``
669+
* ``record_template``: ``"{\"A\": \"${A}\", \"B\": \"${B}\"}"``
670+
671+
Example model input 3 (single record only)::
672+
673+
{ "A": 0, "B": 1 }
674+
675+
content_template and record_template to construct above:
676+
677+
* ``content_template``: ``"$record"``
678+
* ``record_template``: ``"$features_kvp"``
610679
custom_attributes (str): Provides additional information about a request for an
611680
inference submitted to a model hosted at an Amazon SageMaker endpoint. The
612681
information is an opaque value that is forwarded verbatim. You could use this
@@ -677,6 +746,7 @@ def __init__(
677746
if content_type not in [
678747
"text/csv",
679748
"application/jsonlines",
749+
"application/json",
680750
"image/jpeg",
681751
"image/jpg",
682752
"image/png",
@@ -686,14 +756,32 @@ def __init__(
686756
f"Invalid content_type {content_type}."
687757
f" Please choose text/csv or application/jsonlines."
688758
)
759+
if content_type == "application/jsonlines":
760+
if content_template is None:
761+
raise ValueError(
762+
f"content_template field is required for content_type {content_type}"
763+
)
764+
if "$features" not in content_template:
765+
raise ValueError(
766+
f"Invalid content_template {content_template}."
767+
f" Please include a placeholder $features."
768+
)
769+
if content_type == "application/json":
770+
if content_template is None or record_template is None:
771+
raise ValueError(
772+
f"content_template and record_template are required for content_type "
773+
f"{content_type}"
774+
)
775+
if "$record" not in content_template:
776+
raise ValueError(
777+
f"Invalid content_template {content_template}."
778+
f" Please include either placeholder $records or $record."
779+
)
689780
self.predictor_config["content_type"] = content_type
690781
if content_template is not None:
691-
if "$features" not in content_template:
692-
raise ValueError(
693-
f"Invalid content_template {content_template}."
694-
f" Please include a placeholder $features."
695-
)
696782
self.predictor_config["content_template"] = content_template
783+
if record_template is not None:
784+
self.predictor_config["record_template"] = record_template
697785
_set(custom_attributes, "custom_attributes", self.predictor_config)
698786
_set(accelerator_type, "accelerator_type", self.predictor_config)
699787
_set(target_model, "target_model", self.predictor_config)

tests/unit/test_clarify.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_re
393393
("text/csv", "application/json"),
394394
("application/jsonlines", "application/json"),
395395
("application/jsonlines", "text/csv"),
396+
("application/json", "application/json"),
397+
("application/json", "application/jsonlines"),
398+
("application/json", "text/csv"),
396399
("image/jpeg", "text/csv"),
397400
("image/jpg", "text/csv"),
398401
("image/png", "text/csv"),
@@ -406,12 +409,20 @@ def test_valid_model_config(content_type, accept_type):
406409
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
407410
target_model = "target_model_name"
408411
accelerator_type = "ml.eia1.medium"
412+
content_template = (
413+
'{"instances":$features}'
414+
if content_type == "application/jsonlines"
415+
else '$records' if content_type == "application/json" else None
416+
)
417+
record_template = '$features_kvp' if content_type == "application/json" else None
409418
model_config = ModelConfig(
410419
model_name=model_name,
411420
instance_type=instance_type,
412421
instance_count=instance_count,
413422
accept_type=accept_type,
414423
content_type=content_type,
424+
content_template=content_template,
425+
record_template=record_template,
415426
custom_attributes=custom_attributes,
416427
accelerator_type=accelerator_type,
417428
target_model=target_model,
@@ -426,21 +437,36 @@ def test_valid_model_config(content_type, accept_type):
426437
"accelerator_type": accelerator_type,
427438
"target_model": target_model,
428439
}
440+
if content_template is not None:
441+
expected_config["content_template"] = content_template
442+
if record_template is not None:
443+
expected_config["record_template"] = record_template
429444
assert expected_config == model_config.get_predictor_config()
430445

431446

432-
def test_invalid_model_config():
433-
with pytest.raises(ValueError) as error:
447+
@pytest.mark.parametrize(
448+
("error", "content_type", "accept_type", "content_template", "record_template"),
449+
[
450+
("Invalid accept_type invalid_accept_type. Please choose text/csv or application/jsonlines.", "text/csv", "invalid_accept_type", None, None),
451+
("Invalid content_type invalid_content_type. Please choose text/csv or application/jsonlines.", "invalid_content_type", "text/csv", None, None),
452+
("content_template field is required for content_type", "application/jsonlines", "text/csv", None, None),
453+
("content_template and record_template are required for content_type", "application/json", "text/csv", None, None),
454+
("content_template and record_template are required for content_type", "application/json", "text/csv", "$records", None),
455+
("Invalid content_template invalid_content_template. Please include a placeholder \$features.", "application/jsonlines", "text/csv", "invalid_content_template", None),
456+
("Invalid content_template invalid_content_template. Please include either placeholder \$records or \$record.", "application/json", "text/csv", "invalid_content_template", "$features"),
457+
]
458+
)
459+
def test_invalid_model_config(error, content_type, accept_type, content_template, record_template):
460+
with pytest.raises(ValueError, match=error):
434461
ModelConfig(
435462
model_name="xgboost-model",
436463
instance_type="ml.c5.xlarge",
437464
instance_count=1,
438-
accept_type="invalid_accept_type",
465+
content_type=content_type,
466+
accept_type=accept_type,
467+
content_template=content_template,
468+
record_template=record_template
439469
)
440-
assert (
441-
"Invalid accept_type invalid_accept_type. Please choose text/csv or application/jsonlines."
442-
in str(error.value)
443-
)
444470

445471

446472
def test_invalid_model_config_with_bad_endpoint_name_prefix():

0 commit comments

Comments
 (0)