diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 5c69566b14..0ab00c3c2e 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -17,7 +17,7 @@ import subprocess import sagemaker -from sagemaker.model import FrameworkModel, ModelPackage +from sagemaker.model import FrameworkModel, Model, ModelPackage from sagemaker.predictor import RealTimePredictor import pytest @@ -417,37 +417,100 @@ def test_model_enable_network_isolation(sagemaker_session): assert model.enable_network_isolation() is False -@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) -def test_model_create_transformer(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer(create_sagemaker_model, sagemaker_session): + model_name = "auto-generated-model" + model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) + + instance_type = "ml.m4.xlarge" + transformer = model.transformer(instance_count=1, instance_type=instance_type) + + create_sagemaker_model.assert_called_with(instance_type, tags=None) + + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == model_name + assert transformer.instance_type == instance_type + assert transformer.instance_count == 1 + assert transformer.sagemaker_session == sagemaker_session + assert transformer.base_transform_job_name == model_name + + assert transformer.strategy is None + assert transformer.env is None + assert transformer.output_path is None + assert transformer.output_kms_key is None + assert transformer.accept is None + assert transformer.assemble_with is None + assert transformer.volume_kms_key is None + assert transformer.max_concurrent_transforms is None + assert transformer.max_payload is None + assert transformer.tags is None + + +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) - tags = [{"Key": "k", "Value": "v"}] - model = DummyFrameworkModel(sagemaker_session=sagemaker_session) instance_type = "ml.m4.xlarge" - model.name = "auto-generated-model" + strategy = "MultiRecord" + assemble_with = "Line" + output_path = "s3://bucket/path" + kms_key = "key" + accept = "text/csv" + env = {"test": True} + max_concurrent_transforms = 1 + max_payload = 6 + tags = [{"Key": "k", "Value": "v"}] + transformer = model.transformer( - instance_count=1, instance_type=instance_type, env={"test": True}, tags=tags + instance_count=1, + instance_type=instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=kms_key, + accept=accept, + env=env, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + volume_kms_key=kms_key, ) + + create_sagemaker_model.assert_called_with(instance_type, tags=tags) + assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == "auto-generated-model" - assert transformer.instance_type == "ml.m4.xlarge" - assert transformer.env == {"test": True} + assert transformer.strategy == strategy + assert transformer.assemble_with == assemble_with + assert transformer.output_path == output_path + assert transformer.output_kms_key == kms_key + assert transformer.accept == accept + assert transformer.max_concurrent_transforms == max_concurrent_transforms + assert transformer.max_payload == max_payload + assert transformer.env == env + assert transformer.tags == tags + assert transformer.volume_kms_key == kms_key + + +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): + model = Model( + MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True + ) - sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags) + transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) + assert transformer.env is None @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_transformer_creates_correct_session(local_session, session): - model = DummyFrameworkModel(sagemaker_session=None) + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) transformer = model.transformer(instance_count=1, instance_type="local") assert model.sagemaker_session == local_session.return_value assert transformer.sagemaker_session == local_session.return_value - model = DummyFrameworkModel(sagemaker_session=None) + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") assert model.sagemaker_session == session.return_value assert transformer.sagemaker_session == session.return_value