diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index a20f102e56..20ac8c8ffa 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1737,21 +1737,41 @@ def register( @property def model_data(self): - """str: The model location in S3. Only set if Estimator has been ``fit()``.""" + """Str or dict: The model location in S3. Only set if Estimator has been ``fit()``.""" if self.latest_training_job is not None and not isinstance( self.sagemaker_session, PipelineSession ): - model_uri = self.sagemaker_session.sagemaker_client.describe_training_job( + job_details = self.sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=self.latest_training_job.name - )["ModelArtifacts"]["S3ModelArtifacts"] - else: - logger.warning( - "No finished training job found associated with this estimator. Please make sure " - "this estimator is only used for building workflow config" ) - model_uri = os.path.join( - self.output_path, self._current_job_name, "output", "model.tar.gz" + model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"] + compression_type = job_details.get("OutputDataConfig", {}).get( + "CompressionType", "GZIP" ) + if compression_type == "GZIP": + return model_uri + # fail fast if we don't recognize training output compression type + if compression_type not in {"GZIP", "NONE"}: + raise ValueError( + f'Unrecognized training job output data compression type "{compression_type}"' + ) + # model data is in uncompressed form NOTE SageMaker Hosting mandates presence of + # trailing forward slash in S3 model data URI, so append one if necessary. + if not model_uri.endswith("/"): + model_uri += "/" + return { + "S3DataSource": { + "S3Uri": model_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + + logger.warning( + "No finished training job found associated with this estimator. Please make sure " + "this estimator is only used for building workflow config" + ) + model_uri = os.path.join(self.output_path, self._current_job_name, "output", "model.tar.gz") return model_uri @abstractmethod diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 525abe3a04..7ccd7a220d 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -131,7 +131,7 @@ class Model(ModelBase, InferenceRecommenderMixin): def __init__( self, image_uri: Union[str, PipelineVariable], - model_data: Optional[Union[str, PipelineVariable]] = None, + model_data: Optional[Union[str, PipelineVariable, dict]] = None, role: Optional[str] = None, predictor_cls: Optional[callable] = None, env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, @@ -152,8 +152,8 @@ def __init__( Args: image_uri (str or PipelineVariable): A Docker image URI. - model_data (str or PipelineVariable): The S3 location of a SageMaker - model data ``.tar.gz`` file (default: None). + model_data (str or PipelineVariable or dict): Location + of SageMaker model data (default: None). role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model @@ -455,6 +455,11 @@ def register( """ if self.model_data is None: raise ValueError("SageMaker Model Package cannot be created without model data.") + if isinstance(self.model_data, dict): + raise ValueError( + "SageMaker Model Package currently cannot be created with ModelDataSource." + ) + if image_uri is not None: self.image_uri = image_uri @@ -600,6 +605,7 @@ def prepare_container_def( ) self._upload_code(deploy_key_prefix, repack=is_repack) deploy_env.update(self._script_mode_env_vars()) + return sagemaker.container_def( self.image_uri, self.repacked_model_data or self.model_data, @@ -639,6 +645,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: ) if repack and self.model_data is not None and self.entry_point is not None: + if isinstance(self.model_data, dict): + logging.warning("ModelDataSource currently doesn't support model repacking") + return if is_pipeline_variable(self.model_data): # model is not yet there, defer repacking to later during pipeline execution if not isinstance(self.sagemaker_session, PipelineSession): @@ -765,10 +774,16 @@ def _create_sagemaker_model( # _base_name, model_name are not needed under PipelineSession. # the model_data may be Pipeline variable # which may break the _base_name generation + model_uri = None + if isinstance(self.model_data, (str, PipelineVariable)): + model_uri = self.model_data + elif isinstance(self.model_data, dict): + model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None) + self._ensure_base_name_if_needed( image_uri=container_def["Image"], script_uri=self.source_dir, - model_uri=self.model_data, + model_uri=model_uri, ) self._set_model_name_if_needed() @@ -1110,6 +1125,8 @@ def compile( raise ValueError("You must provide a compilation job name") if self.model_data is None: raise ValueError("You must provide an S3 path to the compressed model artifacts.") + if isinstance(self.model_data, dict): + raise ValueError("Compiling model data from ModelDataSource is currently not supported") framework_version = framework_version or self._get_framework_version() @@ -1301,7 +1318,7 @@ def deploy( tags = add_jumpstart_tags( tags=tags, - inference_model_uri=self.model_data, + inference_model_uri=self.model_data if isinstance(self.model_data, str) else None, inference_script_uri=self.source_dir, ) @@ -1545,7 +1562,7 @@ class FrameworkModel(Model): def __init__( self, - model_data: Union[str, PipelineVariable], + model_data: Union[str, PipelineVariable, dict], image_uri: Union[str, PipelineVariable], role: Optional[str] = None, entry_point: Optional[str] = None, @@ -1563,8 +1580,8 @@ def __init__( """Initialize a ``FrameworkModel``. Args: - model_data (str or PipelineVariable): The S3 location of a SageMaker - model data ``.tar.gz`` file. + model_data (str or PipelineVariable or dict): The S3 location of + SageMaker model data. image_uri (str or PipelineVariable): A Docker image URI. role (str): An IAM role name or ARN for SageMaker to access AWS resources on your behalf. @@ -1758,6 +1775,11 @@ def __init__( ``model_data`` is not required. **kwargs: Additional kwargs passed to the Model constructor. """ + if isinstance(model_data, dict): + raise ValueError( + "Creating ModelPackage with ModelDataSource is currently not supported" + ) + super(ModelPackage, self).__init__( role=role, model_data=model_data, image_uri=None, **kwargs ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 7fcea802f4..9946660ed2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3584,7 +3584,7 @@ def create_model_from_job( ) primary_container = container_def( image_uri or training_job["AlgorithmSpecification"]["TrainingImage"], - model_data_url=model_data_url or training_job["ModelArtifacts"]["S3ModelArtifacts"], + model_data_url=model_data_url or self._gen_s3_model_data_source(training_job), env=env, ) vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override) @@ -4486,14 +4486,14 @@ def endpoint_from_job( str: Name of the ``Endpoint`` that is created. """ job_desc = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) - output_url = job_desc["ModelArtifacts"]["S3ModelArtifacts"] + model_s3_location = self._gen_s3_model_data_source(job_desc) image_uri = image_uri or job_desc["AlgorithmSpecification"]["TrainingImage"] role = role or job_desc["RoleArn"] name = name or job_name vpc_config_override = _vpc_config_from_training_job(job_desc, vpc_config_override) return self.endpoint_from_model_data( - model_s3_location=output_url, + model_s3_location=model_s3_location, image_uri=image_uri, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -4506,6 +4506,40 @@ def endpoint_from_job( data_capture_config=data_capture_config, ) + def _gen_s3_model_data_source(self, training_job_spec): + """Generates ``ModelDataSource`` value from given DescribeTrainingJob API response. + + Args: + training_job_spec (dict): SageMaker DescribeTrainingJob API response. + + Returns: + dict: A ``ModelDataSource`` value. + """ + model_data_s3_uri = training_job_spec["ModelArtifacts"]["S3ModelArtifacts"] + compression_type = training_job_spec.get("OutputDataConfig", {}).get( + "CompressionType", "GZIP" + ) + # See https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_OutputDataConfig.html + # and https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3ModelDataSource.html + if compression_type in {"NONE", "GZIP"}: + model_compression_type = compression_type.title() + else: + raise ValueError( + f'Unrecognized training job output data compression type "{compression_type}"' + ) + s3_model_data_type = "S3Object" if model_compression_type == "Gzip" else "S3Prefix" + # if model data is in S3Prefix type and has no trailing forward slash in its URI, + # append one so that it meets SageMaker Hosting's mandate for deploying uncompressed model. + if s3_model_data_type == "S3Prefix" and not model_data_s3_uri.endswith("/"): + model_data_s3_uri += "/" + return { + "S3DataSource": { + "S3Uri": model_data_s3_uri, + "S3DataType": s3_model_data_type, + "CompressionType": model_compression_type, + } + } + def endpoint_from_model_data( self, model_s3_location, @@ -4524,7 +4558,8 @@ def endpoint_from_model_data( """Create and deploy to an ``Endpoint`` using existing model data stored in S3. Args: - model_s3_location (str): S3 URI of the model artifacts to use for the endpoint. + model_s3_location (str or dict): S3 location of the model artifacts + to use for the endpoint. image_uri (str): The Docker image URI which defines the runtime code to be used as the entry point for accepting prediction requests. initial_instance_count (int): Minimum number of EC2 instances to launch. The actual @@ -5925,8 +5960,10 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, Args: image_uri (str): Docker image URI to run for this container. - model_data_url (str): S3 URI of data required by this container, - e.g. SageMaker training job model artifacts (default: None). + model_data_url (str or dict[str, Any]): S3 location of model data required by this + container, e.g. SageMaker training job model artifacts. It can either be a string + representing S3 URI of model data, or a dictionary representing a + ``ModelDataSource`` object. (default: None). env (dict[str, str]): Environment variables to set inside the container (default: None). container_mode (str): The model container mode. Valid modes: * MultiModel: Indicates that model container can support hosting multiple models @@ -5943,8 +5980,12 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, if env is None: env = {} c_def = {"Image": image_uri, "Environment": env} - if model_data_url: + + if isinstance(model_data_url, dict): + c_def["ModelDataSource"] = model_data_url + elif model_data_url: c_def["ModelDataUrl"] = model_data_url + if container_mode: c_def["Mode"] = container_mode if image_config: diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 5553b9d1da..b79202e4ea 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -810,6 +810,39 @@ def test_register_calls_model_package_args(get_model_package_args, sagemaker_ses get_model_package_args""" +def test_register_calls_model_data_source_not_supported(sagemaker_session): + source_dir = "s3://blah/blah/blah" + t = Model( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + source_dir=source_dir, + image_uri=IMAGE_URI, + model_data={ + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + ) + + with pytest.raises( + ValueError, + match="SageMaker Model Package currently cannot be created with ModelDataSource.", + ): + t.register( + SUPPORTED_CONTENT_TYPES, + SUPPORTED_RESPONSE_MIME_TYPES, + SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES, + SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES, + marketplace_cert=True, + description=MODEL_DESCRIPTION, + model_package_name=MODEL_NAME, + validation_specification=VALIDATION_SPECIFICATION, + ) + + @patch("sagemaker.utils.repack_model") def test_model_local_download_dir(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index d769d8e119..ca29644603 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -209,6 +209,24 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): ) +def test_model_package_model_data_source_not_supported(sagemaker_session): + with pytest.raises( + ValueError, match="Creating ModelPackage with ModelDataSource is currently not supported" + ): + ModelPackage( + role="role", + model_package_arn="my-model-package", + model_data={ + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + sagemaker_session=sagemaker_session, + ) + + @patch("sagemaker.utils.name_from_base") def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session): model_package_name = "my-model-package" diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index dce71e6002..f4fbbb39c2 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -244,6 +244,31 @@ def test_compile_validates_model_data(): assert "You must provide an S3 path to the compressed model artifacts." in str(e) +def test_compile_validates_model_data_source(): + model_data_src = { + "S3DataSource": { + "S3Uri": "s3://bucket/model/prefix", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + model = Model(MODEL_IMAGE, model_data=model_data_src) + + with pytest.raises( + ValueError, match="Compiling model data from ModelDataSource is currently not supported" + ) as e: + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) + + assert "Compiling model data from ModelDataSource is currently not supported" in str(e) + + def test_deploy_honors_provided_model_name(sagemaker_session): model = _create_model(sagemaker_session) model._is_compiled_model = True diff --git a/tests/unit/test_endpoint_from_job.py b/tests/unit/test_endpoint_from_job.py index 9cffe4d438..3e38234667 100644 --- a/tests/unit/test_endpoint_from_job.py +++ b/tests/unit/test_endpoint_from_job.py @@ -23,6 +23,13 @@ ACCELERATOR_TYPE = "ml.eia.medium" IMAGE = "myimage" S3_MODEL_ARTIFACTS = "s3://mybucket/mymodel" +S3_MODEL_SRC_COMPRESSED = { + "S3DataSource": { + "S3Uri": S3_MODEL_ARTIFACTS, + "S3DataType": "S3Object", + "CompressionType": "Gzip", + } +} TRAIN_ROLE = "mytrainrole" VPC_CONFIG = {"Subnets": ["subnet-foo"], "SecurityGroupIds": ["sg-foo"]} TRAINING_JOB_RESPONSE = { @@ -68,7 +75,7 @@ def test_all_defaults_no_existing_entities(sagemaker_session): expected_args = original_args.copy() expected_args.pop("job_name") - expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS + expected_args["model_s3_location"] = S3_MODEL_SRC_COMPRESSED expected_args["image_uri"] = IMAGE expected_args["role"] = TRAIN_ROLE expected_args["name"] = JOB_NAME @@ -100,7 +107,7 @@ def test_no_defaults_no_existing_entities(sagemaker_session): expected_args = original_args.copy() expected_args.pop("job_name") - expected_args["model_s3_location"] = S3_MODEL_ARTIFACTS + expected_args["model_s3_location"] = S3_MODEL_SRC_COMPRESSED expected_args["model_vpc_config"] = expected_args.pop("vpc_config_override") expected_args["data_capture_config"] = None sagemaker_session.endpoint_from_model_data.assert_called_once_with(**expected_args) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 8d8015b305..5e1e4d2645 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -108,6 +108,17 @@ DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} +DESCRIBE_TRAINING_JOB_RESULT_UNCOMPRESSED_S3_MODEL = { + "ModelArtifacts": { + "S3ModelArtifacts": "s3://bucket/model/prefix", + }, + "OutputDataConfig": { + "CompressionType": "NONE", + "KmsKeyId": "outputkms", + "S3OutputPath": "s3://path/to/model", + }, +} + RETURNED_JOB_DESCRIPTION = { "AlgorithmSpecification": { "TrainingInputMode": "File", @@ -3312,6 +3323,106 @@ def test_fit_deploy_tags(name_from_base, sagemaker_session): ) +@patch("sagemaker.estimator.name_from_base") +def test_fit_deploy_uncompressed_s3_model(name_from_base, sagemaker_session): + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", + return_value=DESCRIBE_TRAINING_JOB_RESULT_UNCOMPRESSED_S3_MODEL, + ) + estimator = Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) + + estimator.fit() + + model_name = "model_name" + name_from_base.return_value = model_name + + estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE) + + variant = [ + { + "InstanceType": "c4.4xlarge", + "VariantName": "AllTraffic", + "ModelName": model_name, + "InitialVariantWeight": 1, + "InitialInstanceCount": 1, + } + ] + + name_from_base.assert_called_with(IMAGE_URI) + + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=model_name, + production_variants=variant, + kms_key=None, + wait=True, + data_capture_config_dict=None, + async_inference_config_dict=None, + explainer_config_dict=None, + tags=None, + ) + + sagemaker_session.create_model.assert_called_with( + name=model_name, + role="DummyRole", + container_defs={ + "ModelDataSource": { + "S3DataSource": { + # S3 URI passed to Createmodel API should have trailing forward slash appeneded + "S3Uri": "s3://bucket/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + "Environment": {}, + "Image": "fakeimage", + }, + enable_network_isolation=False, + vpc_config=None, + tags=None, + ) + + +@patch("sagemaker.estimator.name_from_base") +def test_fit_deploy_uncompressed_s3_model_unrecognized_compression_type( + name_from_base, sagemaker_session +): + training_job_desc = deepcopy(DESCRIBE_TRAINING_JOB_RESULT_UNCOMPRESSED_S3_MODEL) + training_job_desc["OutputDataConfig"]["CompressionType"] = "JUNK" + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", + return_value=training_job_desc, + ) + estimator = Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, + ) + + estimator.fit() + + model_name = "model_name" + name_from_base.return_value = model_name + + with pytest.raises( + ValueError, + match='Unrecognized training job output data compression type "JUNK"', + ): + estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE) + + name_from_base.assert_called_with(IMAGE_URI) + + sagemaker_session.endpoint_from_production_variants.assert_not_called() + sagemaker_session.create_model.assert_not_called() + + @patch("time.time", return_value=TIME) def test_generic_to_fit_no_input(time, sagemaker_session): e = Estimator( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7f79deed33..464d2f65a8 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -990,6 +990,13 @@ def test_training_input_all_arguments(): {"TrainingEndTime": datetime.datetime(2018, 2, 17, 7, 19, 34, 953000)} ) +COMPLETED_DESCRIBE_JOB_RESULT_UNCOMPRESSED_S3_MODEL = copy.deepcopy(COMPLETED_DESCRIBE_JOB_RESULT) +COMPLETED_DESCRIBE_JOB_RESULT_UNCOMPRESSED_S3_MODEL["ModelArtifacts"]["S3ModelArtifacts"] = ( + S3_OUTPUT + "/model/prefix" +) +COMPLETED_DESCRIBE_JOB_RESULT_UNCOMPRESSED_S3_MODEL["OutputDataConfig"]["CompressionType"] = "NONE" + + STOPPED_DESCRIBE_JOB_RESULT = dict(COMPLETED_DESCRIBE_JOB_RESULT) STOPPED_DESCRIBE_JOB_RESULT.update({"TrainingJobStatus": "Stopped"}) @@ -2618,6 +2625,29 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ "Image": IMAGE, "ModelDataUrl": "s3://sagemaker-123/output/jobname/model/model.tar.gz", } +PRIMARY_CONTAINER_WITH_COMPRESSED_S3_MODEL = { + "Environment": {}, + "Image": IMAGE, + "ModelDataSource": { + "S3DataSource": { + "S3Uri": "s3://sagemaker-123/output/jobname/model/model.tar.gz", + "S3DataType": "S3Object", + "CompressionType": "Gzip", + } + }, +} +PRIMARY_CONTAINER_WITH_UNCOMPRESSED_S3_MODEL = { + "Environment": {}, + "Image": IMAGE, + "ModelDataSource": { + "S3DataSource": { + # expect model data URI has trailing forward slash appended + "S3Uri": "s3://sagemaker-123/output/jobname/model/prefix/", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, +} def test_create_model_with_sagemaker_config_injection_with_primary_container(sagemaker_session): @@ -2814,7 +2844,7 @@ def test_create_model_from_job(sagemaker_session): ims.sagemaker_client.create_model.assert_called_with( ExecutionRoleArn=EXPANDED_ROLE, ModelName=JOB_NAME, - PrimaryContainer=PRIMARY_CONTAINER, + PrimaryContainer=PRIMARY_CONTAINER_WITH_COMPRESSED_S3_MODEL, VpcConfig=VPC_CONFIG, ) @@ -2830,12 +2860,50 @@ def test_create_model_from_job_with_tags(sagemaker_session): ims.sagemaker_client.create_model.assert_called_with( ExecutionRoleArn=EXPANDED_ROLE, ModelName=JOB_NAME, - PrimaryContainer=PRIMARY_CONTAINER, + PrimaryContainer=PRIMARY_CONTAINER_WITH_COMPRESSED_S3_MODEL, VpcConfig=VPC_CONFIG, Tags=TAGS, ) +def test_create_model_from_job_uncompressed_s3_model(sagemaker_session): + ims = sagemaker_session + ims.sagemaker_client.describe_training_job.return_value = ( + COMPLETED_DESCRIBE_JOB_RESULT_UNCOMPRESSED_S3_MODEL + ) + ims.create_model_from_job(JOB_NAME) + + assert ( + call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list + ) + ims.sagemaker_client.create_model.assert_called_with( + ExecutionRoleArn=EXPANDED_ROLE, + ModelName=JOB_NAME, + PrimaryContainer=PRIMARY_CONTAINER_WITH_UNCOMPRESSED_S3_MODEL, + VpcConfig=VPC_CONFIG, + ) + + +def test_create_model_from_job_uncompressed_s3_model_unrecognized_compression_type( + sagemaker_session, +): + ims = sagemaker_session + job_desc = copy.deepcopy(COMPLETED_DESCRIBE_JOB_RESULT_UNCOMPRESSED_S3_MODEL) + job_desc["OutputDataConfig"]["CompressionType"] = "JUNK" + ims.sagemaker_client.describe_training_job.return_value = job_desc + + with pytest.raises( + ValueError, match='Unrecognized training job output data compression type "JUNK"' + ): + ims.create_model_from_job(JOB_NAME) + + assert ( + call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list + ) + + ims.sagemaker_client.create_model.assert_not_called() + + def test_create_edge_packaging_with_sagemaker_config_injection(sagemaker_session): sagemaker_session.sagemaker_config = SAGEMAKER_CONFIG_EDGE_PACKAGING_JOB