diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ce0eec2665..017e099ae1 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -115,6 +115,19 @@ def __init__( self._enable_network_isolation = enable_network_isolation self.model_kms_key = model_kms_key + def _init_sagemaker_session_if_does_not_exist(self, instance_type): + """Set ``self.sagemaker_session`` to be a ``LocalSession`` or + ``Session`` if it is not already. The type of session object is + determined by the instance type. + """ + if self.sagemaker_session: + return + + if instance_type in ("local", "local_gpu"): + self.sagemaker_session = local.LocalSession() + else: + self.sagemaker_session = session.Session() + def prepare_container_def( self, instance_type, accelerator_type=None ): # pylint: disable=unused-argument @@ -164,6 +177,8 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type) self.name = self.name or utils.name_from_image(container_def["Image"]) enable_network_isolation = self.enable_network_isolation() + + self._init_sagemaker_session_if_does_not_exist(instance_type) self.sagemaker_session.create_model( self.name, self.role, @@ -324,6 +339,7 @@ def compile( framework = framework.upper() framework_version = self._get_framework_version() or framework_version + self._init_sagemaker_session_if_does_not_exist(target_instance_family) config = self._compilation_job_config( target_instance_family, input_shape, @@ -413,11 +429,7 @@ def deploy( ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` is not None. Otherwise, return None. """ - if not self.sagemaker_session: - if instance_type in ("local", "local_gpu"): - self.sagemaker_session = local.LocalSession() - else: - self.sagemaker_session = session.Session() + self._init_sagemaker_session_if_does_not_exist(instance_type) if self.role is None: raise ValueError("Role can not be null for deploying a model") @@ -514,6 +526,8 @@ def transformer( volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML compute instance (default: None). """ + self._init_sagemaker_session_if_does_not_exist(instance_type) + self._create_sagemaker_model(instance_type, tags=tags) if self.enable_network_isolation(): env = None diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 22d56ac117..254eefc088 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -435,6 +435,21 @@ def test_model_create_transformer(sagemaker_session): sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags) +@patch("sagemaker.session.Session") +@patch("sagemaker.local.LocalSession") +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_transformer_creates_correct_session(local_session, session): + model = DummyFrameworkModel(sagemaker_session=None) + transformer = model.transformer(instance_count=1, instance_type="local") + assert model.sagemaker_session == local_session.return_value + assert transformer.sagemaker_session == local_session.return_value + + model = DummyFrameworkModel(sagemaker_session=None) + transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") + assert model.sagemaker_session == session.return_value + assert transformer.sagemaker_session == session.return_value + + def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): sagemaker_session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE @@ -561,6 +576,24 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir): assert model._is_compiled_model is True +@patch("sagemaker.session.Session") +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_compile_creates_session(session): + session.return_value.boto_region_name = "us-west-2" + + model = DummyFrameworkModel(sagemaker_session=None) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) + + assert model.sagemaker_session == session.return_value + + def test_check_neo_region(sagemaker_session, tmpdir): sagemaker_session.wait_for_compilation_job = Mock( return_value=DESCRIBE_COMPILATION_JOB_RESPONSE