diff --git a/src/sagemaker/base_serializers.py b/src/sagemaker/base_serializers.py
index f40cb068bd..45fea23493 100644
--- a/src/sagemaker/base_serializers.py
+++ b/src/sagemaker/base_serializers.py
@@ -397,6 +397,8 @@ def serialize(self, data):
                 raise ValueError(f"Could not open/read file: {data}. {e}")
         if isinstance(data, bytes):
             return data
+        if isinstance(data, dict) and "data" in data:
+            return self.serialize(data["data"])
         raise ValueError(f"Object of type {type(data)} is not Data serializable.")
 
 
diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py
index d0f65716d8..2a0d4892e4 100644
--- a/src/sagemaker/serve/builder/schema_builder.py
+++ b/src/sagemaker/serve/builder/schema_builder.py
@@ -164,6 +164,11 @@ def _get_serializer(self, obj):
             return StringSerializer()
         if _is_jsonable(obj):
             return JSONSerializerWrapper()
+        if isinstance(obj, dict) and "content_type" in obj:
+            try:
+                return DataSerializer(content_type=obj["content_type"])
+            except ValueError as e:
+                logger.error(e)
 
         raise ValueError(
             (
diff --git a/src/sagemaker/serve/builder/transformers_builder.py b/src/sagemaker/serve/builder/transformers_builder.py
index 2d4cefcb4f..3d84e314df 100644
--- a/src/sagemaker/serve/builder/transformers_builder.py
+++ b/src/sagemaker/serve/builder/transformers_builder.py
@@ -94,7 +94,7 @@ def _create_transformers_model(self) -> Type[Model]:
         )
         hf_config = image_uris.config_for_framework("huggingface").get("inference")
         config = hf_config["versions"]
-        base_hf_version = sorted(config.keys(), key=lambda v: Version(v))[0]
+        base_hf_version = sorted(config.keys(), key=lambda v: Version(v), reverse=True)[0]
 
         if hf_model_md is None:
             raise ValueError("Could not fetch HF metadata")
@@ -269,7 +269,7 @@ def _get_supported_version(self, hf_config, hugging_face_version, base_fw):
             if len(hugging_face_version.split(".")) == 2:
                 base_fw_version = ".".join(base_fw_version.split(".")[:-1])
             versions_to_return.append(base_fw_version)
-        return sorted(versions_to_return)[0]
+        return sorted(versions_to_return, reverse=True)[0]
 
     def _build_for_transformers(self):
         """Method that triggers model build
diff --git a/tests/integ/sagemaker/serve/test_schema_builder.py b/tests/integ/sagemaker/serve/test_schema_builder.py
index 3fd72fb371..a0c1673ae8 100644
--- a/tests/integ/sagemaker/serve/test_schema_builder.py
+++ b/tests/integ/sagemaker/serve/test_schema_builder.py
@@ -202,6 +202,33 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
         ), f"{caught_ex} was thrown when running transformers sagemaker endpoint test"
 
 
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="Testing Schema Builder Simplification feature - Remote Schema",
+)
+@pytest.mark.parametrize(
+    "model_id, task_provided, instance_type_provided",
+    [("openai/whisper-tiny.en", "automatic-speech-recognition", "ml.m5.4xlarge")],
+)
+def test_model_builder_with_task_provided_remote_schema_mode_asr(
+    model_id, task_provided, sagemaker_session, instance_type_provided
+):
+    model_builder = ModelBuilder(
+        model=model_id,
+        model_metadata={"HF_TASK": task_provided},
+        instance_type=instance_type_provided,
+    )
+    model = model_builder.build(sagemaker_session=sagemaker_session)
+
+    assert model is not None
+    assert model_builder.schema_builder is not None
+
+    remote_hf_schema_helper = remote_schema_retriever.RemoteSchemaRetriever()
+    inputs, outputs = remote_hf_schema_helper.get_resolved_hf_schema_for_task(task_provided)
+    assert model_builder.schema_builder.sample_input == inputs
+    assert model_builder.schema_builder.sample_output == outputs
+
+
 def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
     model_builder = ModelBuilder(
         model="bert-base-uncased", model_metadata={"HF_TASK": "invalid-task"}