Skip to content

fix: create Session or LocalSession if not specified in Model #1288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ def __init__(
self._enable_network_isolation = enable_network_isolation
self.model_kms_key = model_kms_key

def _init_sagemaker_session_if_does_not_exist(self, instance_type):
"""Set ``self.sagemaker_session`` to be a ``LocalSession`` or
``Session`` if it is not already. The type of session object is
determined by the instance type.
"""
if self.sagemaker_session:
return

if instance_type in ("local", "local_gpu"):
self.sagemaker_session = local.LocalSession()
else:
self.sagemaker_session = session.Session()

def prepare_container_def(
self, instance_type, accelerator_type=None
): # pylint: disable=unused-argument
Expand Down Expand Up @@ -164,6 +177,8 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=Non
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
self.name = self.name or utils.name_from_image(container_def["Image"])
enable_network_isolation = self.enable_network_isolation()

self._init_sagemaker_session_if_does_not_exist(instance_type)
self.sagemaker_session.create_model(
self.name,
self.role,
Expand Down Expand Up @@ -210,6 +225,7 @@ def _compilation_job_config(
else json.dumps(input_shape),
"Framework": framework,
}

role = self.sagemaker_session.expand_role(role)
output_model_config = {
"TargetDevice": target_instance_type,
Expand Down Expand Up @@ -324,6 +340,7 @@ def compile(
framework = framework.upper()
framework_version = self._get_framework_version() or framework_version

self._init_sagemaker_session_if_does_not_exist(target_instance_family)
config = self._compilation_job_config(
target_instance_family,
input_shape,
Expand Down Expand Up @@ -413,11 +430,7 @@ def deploy(
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
is not None. Otherwise, return None.
"""
if not self.sagemaker_session:
if instance_type in ("local", "local_gpu"):
self.sagemaker_session = local.LocalSession()
else:
self.sagemaker_session = session.Session()
self._init_sagemaker_session_if_does_not_exist(instance_type)

if self.role is None:
raise ValueError("Role can not be null for deploying a model")
Expand Down Expand Up @@ -514,6 +527,8 @@ def transformer(
volume_kms_key (str): Optional. KMS key ID for encrypting the volume
attached to the ML compute instance (default: None).
"""
self._init_sagemaker_session_if_does_not_exist(instance_type)

self._create_sagemaker_model(instance_type, tags=tags)
if self.enable_network_isolation():
env = None
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,21 @@ def test_model_create_transformer(sagemaker_session):
sagemaker.model.Model._create_sagemaker_model.assert_called_with(instance_type, tags=tags)


@patch("sagemaker.session.Session")
@patch("sagemaker.local.LocalSession")
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
def test_transformer_creates_correct_session(local_session, session):
model = DummyFrameworkModel(sagemaker_session=None)
transformer = model.transformer(instance_count=1, instance_type="local")
assert model.sagemaker_session == local_session.return_value
assert transformer.sagemaker_session == local_session.return_value

model = DummyFrameworkModel(sagemaker_session=None)
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
assert model.sagemaker_session == session.return_value
assert transformer.sagemaker_session == session.return_value


def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
sagemaker_session.sagemaker_client.describe_model_package = Mock(
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
Expand Down Expand Up @@ -561,6 +576,24 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
assert model._is_compiled_model is True


@patch("sagemaker.session.Session")
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
def test_compile_creates_session(session):
session.return_value.boto_region_name = "us-west-2"

model = DummyFrameworkModel(sagemaker_session=None)
model.compile(
target_instance_family="ml_c4",
input_shape={"data": [1, 3, 1024, 1024]},
output_path="s3://output",
role="role",
framework="tensorflow",
job_name="compile-model",
)

assert model.sagemaker_session == session.return_value


def test_check_neo_region(sagemaker_session, tmpdir):
sagemaker_session.wait_for_compilation_job = Mock(
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
Expand Down