Skip to content

Commit 3896030

Browse files
committed
add docstring and fix unit test
1 parent 7cdf9ea commit 3896030

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/sagemaker/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def __init__(
116116
self.model_kms_key = model_kms_key
117117

118118
def _init_sagemaker_session_if_does_not_exist(self, instance_type):
119+
"""Set ``self.sagemaker_session`` to be a ``LocalSession`` or
120+
``Session`` if it is not already. The type of session object is
121+
determined by the instance type.
122+
"""
119123
if self.sagemaker_session:
120124
return
121125

tests/unit/test_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -579,16 +579,19 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
579579
@patch("sagemaker.session.Session")
580580
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
581581
def test_compile_creates_session(session):
582+
session.return_value.boto_region_name = "us-west-2"
583+
582584
model = DummyFrameworkModel(sagemaker_session=None)
583585
model.compile(
584586
target_instance_family="ml_c4",
585587
input_shape={"data": [1, 3, 1024, 1024]},
586588
output_path="s3://output",
587589
role="role",
588590
framework="tensorflow",
591+
job_name="compile-model",
589592
)
590593

591-
assert model.sagemaker_sesion == session.return_value
594+
assert model.sagemaker_session == session.return_value
592595

593596

594597
def test_check_neo_region(sagemaker_session, tmpdir):

0 commit comments

Comments
 (0)