Skip to content

Commit 4d1f901

Browse files
spoorn and Michael Trinh authored
feature: Add support for JSON model inputs for Clarify Processor (#3768)
Co-authored-by: Michael Trinh <[email protected]>
1 parent f03b2ed commit 4d1f901

File tree

3 files changed

+181
-19
lines changed

3 files changed

+181
-19
lines changed

src/sagemaker/clarify.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@
282282
in (
283283
"text/csv",
284284
"application/jsonlines",
285+
"application/json",
285286
"image/jpeg",
286287
"image/png",
287288
"application/x-npy",
@@ -296,6 +297,7 @@
296297
SchemaOptional("probability"): Or(str, int),
297298
SchemaOptional("label_headers"): [Or(str, int)],
298299
SchemaOptional("content_template"): Or(str, {str: str}),
300+
SchemaOptional("record_template"): str,
299301
SchemaOptional("custom_attributes"): str,
300302
},
301303
}
@@ -573,6 +575,7 @@ def __init__(
573575
accept_type: Optional[str] = None,
574576
content_type: Optional[str] = None,
575577
content_template: Optional[str] = None,
578+
record_template: Optional[str] = None,
576579
custom_attributes: Optional[str] = None,
577580
accelerator_type: Optional[str] = None,
578581
endpoint_name_prefix: Optional[str] = None,
@@ -599,14 +602,80 @@ def __init__(
599602
``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
600603
Default is the same as ``content_type``.
601604
content_type (str): The model input format to be used for getting inferences with the
602-
shadow endpoint. Valid values are ``"text/csv"`` for CSV and
603-
``"application/jsonlines"`` for JSON Lines. Default is the same as
604-
``dataset_format``.
605+
shadow endpoint. Valid values are ``"text/csv"`` for CSV,
606+
``"application/jsonlines"`` for JSON Lines, and ``"application/json"`` for JSON.
607+
Default is the same as ``dataset_format``.
605608
content_template (str): A template string to be used to construct the model input from
606-
dataset instances. It is only used when ``model_content_type`` is
607-
``"application/jsonlines"``. The template should have one and only one placeholder,
608-
``"features"``, which will be replaced by a features list to form the model
609-
inference input.
609+
dataset instances. It is only used, and required, when ``model_content_type`` is
610+
``"application/jsonlines"`` or ``"application/json"``. When ``model_content_type``
611+
is ``application/jsonlines``, the template should have one and only one
612+
placeholder, ``$features``, which will be replaced by a features list for each
613+
record to form the model inference input. When ``model_content_type`` is
614+
``application/json``, the template can have either placeholder ``$record``, which
615+
will be replaced by a single record templated by ``record_template`` and only a
616+
single record at a time will be sent to the model, or placeholder ``$records``,
617+
which will be replaced by a list of records, each templated by ``record_template``.
618+
record_template (str): A template string to be used to construct each record of the
619+
model input from dataset instances. It is only used, and required, when
620+
``model_content_type`` is ``"application/json"``.
621+
The template string may contain one of the following:
622+
623+
* Placeholder ``$features`` that will be substituted by the array of feature values
624+
and/or an optional placeholder ``$feature_names`` that will be substituted by the
625+
array of feature names.
626+
* Exactly one placeholder ``$features_kvp`` that will be substituted by the
627+
key-value pairs of feature name and feature value.
628+
* Or for each feature, if "A" is the feature name in the ``headers`` configuration,
629+
then placeholder syntax ``"${A}"`` (the double-quotes are part of the
630+
placeholder) will be substituted by the feature value.
631+
632+
``record_template`` will be used in conjunction with ``content_template`` to
633+
construct the model input.
634+
635+
**Examples:**
636+
637+
Given:
638+
639+
* ``headers``: ``["A", "B"]``
640+
* ``features``: ``[[0, 1], [3, 4]]``
641+
642+
Example model input 1::
643+
644+
{
645+
"instances": [[0, 1], [3, 4]],
646+
"feature_names": ["A", "B"]
647+
}
648+
649+
content_template and record_template to construct above:
650+
651+
* ``content_template``: ``"{\"instances\": $records}"``
652+
* ``record_template``: ``"$features"``
653+
654+
Example model input 2::
655+
656+
[
657+
{ "A": 0, "B": 1 },
658+
{ "A": 3, "B": 4 }
659+
]
660+
661+
content_template and record_template to construct above:
662+
663+
* ``content_template``: ``"$records"``
664+
* ``record_template``: ``"$features_kvp"``
665+
666+
Or, alternatively:
667+
668+
* ``content_template``: ``"$records"``
669+
* ``record_template``: ``"{\"A\": \"${A}\", \"B\": \"${B}\"}"``
670+
671+
Example model input 3 (single record only)::
672+
673+
{ "A": 0, "B": 1 }
674+
675+
content_template and record_template to construct above:
676+
677+
* ``content_template``: ``"$record"``
678+
* ``record_template``: ``"$features_kvp"``
610679
custom_attributes (str): Provides additional information about a request for an
611680
inference submitted to a model hosted at an Amazon SageMaker endpoint. The
612681
information is an opaque value that is forwarded verbatim. You could use this
@@ -677,6 +746,7 @@ def __init__(
677746
if content_type not in [
678747
"text/csv",
679748
"application/jsonlines",
749+
"application/json",
680750
"image/jpeg",
681751
"image/jpg",
682752
"image/png",
@@ -686,14 +756,32 @@ def __init__(
686756
f"Invalid content_type {content_type}."
687757
f" Please choose text/csv or application/jsonlines."
688758
)
759+
if content_type == "application/jsonlines":
760+
if content_template is None:
761+
raise ValueError(
762+
f"content_template field is required for content_type {content_type}"
763+
)
764+
if "$features" not in content_template:
765+
raise ValueError(
766+
f"Invalid content_template {content_template}."
767+
f" Please include a placeholder $features."
768+
)
769+
if content_type == "application/json":
770+
if content_template is None or record_template is None:
771+
raise ValueError(
772+
f"content_template and record_template are required for content_type "
773+
f"{content_type}"
774+
)
775+
if "$record" not in content_template:
776+
raise ValueError(
777+
f"Invalid content_template {content_template}."
778+
f" Please include either placeholder $records or $record."
779+
)
689780
self.predictor_config["content_type"] = content_type
690781
if content_template is not None:
691-
if "$features" not in content_template:
692-
raise ValueError(
693-
f"Invalid content_template {content_template}."
694-
f" Please include a placeholder $features."
695-
)
696782
self.predictor_config["content_template"] = content_template
783+
if record_template is not None:
784+
self.predictor_config["record_template"] = record_template
697785
_set(custom_attributes, "custom_attributes", self.predictor_config)
698786
_set(accelerator_type, "accelerator_type", self.predictor_config)
699787
_set(target_model, "target_model", self.predictor_config)

tests/unit/sagemaker/monitor/test_clarify_model_monitor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@
365365
MODEL_NAME = "xgboost-model"
366366
ACCEPT_TYPE = "text/csv"
367367
CONTENT_TYPE = "application/jsonlines"
368+
JSONLINES_CONTENT_TEMPLATE = '{"instances":$features}'
368369
EXPLAINABILITY_ANALYSIS_CONFIG = {
369370
"headers": ANALYSIS_CONFIG_HEADERS_OF_FEATURES,
370371
"methods": {
@@ -382,6 +383,7 @@
382383
"initial_instance_count": INSTANCE_COUNT,
383384
"accept_type": ACCEPT_TYPE,
384385
"content_type": CONTENT_TYPE,
386+
"content_template": JSONLINES_CONTENT_TEMPLATE,
385387
},
386388
}
387389
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG)
@@ -489,6 +491,7 @@ def model_config():
489491
instance_count=INSTANCE_COUNT,
490492
content_type=CONTENT_TYPE,
491493
accept_type=ACCEPT_TYPE,
494+
content_template=JSONLINES_CONTENT_TEMPLATE,
492495
)
493496

494497

tests/unit/test_clarify.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_re
393393
("text/csv", "application/json"),
394394
("application/jsonlines", "application/json"),
395395
("application/jsonlines", "text/csv"),
396+
("application/json", "application/json"),
397+
("application/json", "application/jsonlines"),
398+
("application/json", "text/csv"),
396399
("image/jpeg", "text/csv"),
397400
("image/jpg", "text/csv"),
398401
("image/png", "text/csv"),
@@ -406,12 +409,22 @@ def test_valid_model_config(content_type, accept_type):
406409
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
407410
target_model = "target_model_name"
408411
accelerator_type = "ml.eia1.medium"
412+
content_template = (
413+
'{"instances":$features}'
414+
if content_type == "application/jsonlines"
415+
else "$records"
416+
if content_type == "application/json"
417+
else None
418+
)
419+
record_template = "$features_kvp" if content_type == "application/json" else None
409420
model_config = ModelConfig(
410421
model_name=model_name,
411422
instance_type=instance_type,
412423
instance_count=instance_count,
413424
accept_type=accept_type,
414425
content_type=content_type,
426+
content_template=content_template,
427+
record_template=record_template,
415428
custom_attributes=custom_attributes,
416429
accelerator_type=accelerator_type,
417430
target_model=target_model,
@@ -426,21 +439,79 @@ def test_valid_model_config(content_type, accept_type):
426439
"accelerator_type": accelerator_type,
427440
"target_model": target_model,
428441
}
442+
if content_template is not None:
443+
expected_config["content_template"] = content_template
444+
if record_template is not None:
445+
expected_config["record_template"] = record_template
429446
assert expected_config == model_config.get_predictor_config()
430447

431448

432-
def test_invalid_model_config():
433-
with pytest.raises(ValueError) as error:
449+
@pytest.mark.parametrize(
450+
("error", "content_type", "accept_type", "content_template", "record_template"),
451+
[
452+
(
453+
"Invalid accept_type invalid_accept_type. Please choose text/csv or application/jsonlines.",
454+
"text/csv",
455+
"invalid_accept_type",
456+
None,
457+
None,
458+
),
459+
(
460+
"Invalid content_type invalid_content_type. Please choose text/csv or application/jsonlines.",
461+
"invalid_content_type",
462+
"text/csv",
463+
None,
464+
None,
465+
),
466+
(
467+
"content_template field is required for content_type",
468+
"application/jsonlines",
469+
"text/csv",
470+
None,
471+
None,
472+
),
473+
(
474+
"content_template and record_template are required for content_type",
475+
"application/json",
476+
"text/csv",
477+
None,
478+
None,
479+
),
480+
(
481+
"content_template and record_template are required for content_type",
482+
"application/json",
483+
"text/csv",
484+
"$records",
485+
None,
486+
),
487+
(
488+
r"Invalid content_template invalid_content_template. Please include a placeholder \$features.",
489+
"application/jsonlines",
490+
"text/csv",
491+
"invalid_content_template",
492+
None,
493+
),
494+
(
495+
r"Invalid content_template invalid_content_template. Please include either placeholder "
496+
r"\$records or \$record.",
497+
"application/json",
498+
"text/csv",
499+
"invalid_content_template",
500+
"$features",
501+
),
502+
],
503+
)
504+
def test_invalid_model_config(error, content_type, accept_type, content_template, record_template):
505+
with pytest.raises(ValueError, match=error):
434506
ModelConfig(
435507
model_name="xgboost-model",
436508
instance_type="ml.c5.xlarge",
437509
instance_count=1,
438-
accept_type="invalid_accept_type",
510+
content_type=content_type,
511+
accept_type=accept_type,
512+
content_template=content_template,
513+
record_template=record_template,
439514
)
440-
assert (
441-
"Invalid accept_type invalid_accept_type. Please choose text/csv or application/jsonlines."
442-
in str(error.value)
443-
)
444515

445516

446517
def test_invalid_model_config_with_bad_endpoint_name_prefix():

0 commit comments

Comments
 (0)