Skip to content

Commit b6247a3

Browse files
Dan authored
fix: add better default transform job name handling within Transformer (#822)
1 parent 1277136 commit b6247a3

File tree

2 files changed: +93 -8 lines changed

src/sagemaker/transformer.py

+24 -2
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,11 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
105105
if job_name is not None:
106106
self._current_job_name = job_name
107107
else:
108-
base_name = self.base_transform_job_name or base_name_from_image(self._retrieve_image_name())
108+
base_name = self.base_transform_job_name
109+
110+
if base_name is None:
111+
base_name = self._retrieve_base_name()
112+
109113
self._current_job_name = name_from_base(base_name)
110114

111115
if self.output_path is None:
@@ -120,10 +124,28 @@ def delete_model(self):
120124
"""
121125
self.sagemaker_session.delete_model(self.model_name)
122126

127+
def _retrieve_base_name(self):
128+
image_name = self._retrieve_image_name()
129+
130+
if image_name:
131+
return base_name_from_image(image_name)
132+
133+
return self.model_name
134+
123135
def _retrieve_image_name(self):
124136
try:
125137
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
126-
return model_desc['PrimaryContainer']['Image']
138+
139+
primary_container = model_desc.get('PrimaryContainer')
140+
if primary_container:
141+
return primary_container.get('Image')
142+
143+
containers = model_desc.get('Containers')
144+
if containers:
145+
return containers[0].get('Image')
146+
147+
return None
148+
127149
except exceptions.ClientError:
128150
raise ValueError('Failed to fetch model information for %s. '
129151
'Please ensure that the model exists. '

tests/unit/test_transformer.py

+69 -6
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
from mock import MagicMock, Mock, patch
1717

18-
from sagemaker.transformer import Transformer, _TransformJob
18+
from sagemaker.transformer import _TransformJob, Transformer
1919
from tests.integ import test_local_mode
2020

2121
MODEL_NAME = 'model'
@@ -40,6 +40,18 @@
4040
'base_transform_job_name': JOB_NAME
4141
}
4242

43+
MODEL_DESC_PRIMARY_CONTAINER = {
44+
'PrimaryContainer': {
45+
'Image': IMAGE_NAME
46+
}
47+
}
48+
49+
MODEL_DESC_CONTAINERS_ONLY = {
50+
'Containers': [
51+
{'Image': IMAGE_NAME}
52+
]
53+
}
54+
4355

4456
@pytest.fixture(autouse=True)
4557
def mock_create_tar_file():
@@ -97,29 +109,80 @@ def test_transform_with_all_params(start_new_job, transformer):
97109

98110
@patch('sagemaker.transformer.name_from_base')
99111
@patch('sagemaker.transformer._TransformJob.start_new')
100-
def test_transform_with_base_job_name(start_new_job, name_from_base, transformer):
112+
def test_transform_with_base_job_name_provided(start_new_job, name_from_base, transformer):
101113
base_name = 'base-job-name'
102114
full_name = '{}-{}'.format(base_name, TIMESTAMP)
103115

104116
transformer.base_transform_job_name = base_name
105117
name_from_base.return_value = full_name
106118

107119
transformer.transform(DATA)
108-
assert name_from_base.called_with(base_name)
120+
121+
name_from_base.assert_called_once_with(base_name)
122+
assert transformer._current_job_name == full_name
123+
124+
125+
@patch('sagemaker.transformer.Transformer._retrieve_base_name', return_value=IMAGE_NAME)
126+
@patch('sagemaker.transformer.name_from_base')
127+
@patch('sagemaker.transformer._TransformJob.start_new')
128+
def test_transform_with_base_name(start_new_job, name_from_base, retrieve_base_name, transformer):
129+
full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
130+
name_from_base.return_value = full_name
131+
132+
transformer.transform(DATA)
133+
134+
retrieve_base_name.assert_called_once_with()
135+
name_from_base.assert_called_once_with(IMAGE_NAME)
109136
assert transformer._current_job_name == full_name
110137

111138

112139
@patch('sagemaker.transformer.Transformer._retrieve_image_name', return_value=IMAGE_NAME)
113140
@patch('sagemaker.transformer.name_from_base')
114141
@patch('sagemaker.transformer._TransformJob.start_new')
115-
def test_transform_with_fully_generated_job_name(start_new_job, name_from_base, retrieve_image_name, transformer):
142+
def test_transform_with_job_name_based_on_image(start_new_job, name_from_base, retrieve_image_name, transformer):
116143
full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
117144
name_from_base.return_value = full_name
118145

119146
transformer.transform(DATA)
120147

121-
assert retrieve_image_name.called_once
122-
assert name_from_base.called_with(IMAGE_NAME)
148+
retrieve_image_name.assert_called_once_with()
149+
name_from_base.assert_called_once_with(IMAGE_NAME)
150+
assert transformer._current_job_name == full_name
151+
152+
153+
@pytest.mark.parametrize('model_desc', [MODEL_DESC_PRIMARY_CONTAINER,
154+
MODEL_DESC_CONTAINERS_ONLY])
155+
@patch('sagemaker.transformer.name_from_base')
156+
@patch('sagemaker.transformer._TransformJob.start_new')
157+
def test_transform_with_job_name_based_on_containers(start_new_job, name_from_base, model_desc, transformer):
158+
transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc
159+
160+
full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
161+
name_from_base.return_value = full_name
162+
163+
transformer.transform(DATA)
164+
165+
transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME)
166+
name_from_base.assert_called_once_with(IMAGE_NAME)
167+
assert transformer._current_job_name == full_name
168+
169+
170+
@pytest.mark.parametrize('model_desc', [{'PrimaryContainer': dict()},
171+
{'Containers': [dict()]},
172+
dict(),
173+
])
174+
@patch('sagemaker.transformer.name_from_base')
175+
@patch('sagemaker.transformer._TransformJob.start_new')
176+
def test_transform_with_job_name_based_on_model_name(start_new_job, name_from_base, model_desc, transformer):
177+
transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc
178+
179+
full_name = '{}-{}'.format(MODEL_NAME, TIMESTAMP)
180+
name_from_base.return_value = full_name
181+
182+
transformer.transform(DATA)
183+
184+
transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME)
185+
name_from_base.assert_called_once_with(MODEL_NAME)
123186
assert transformer._current_job_name == full_name
124187

125188

0 commit comments

Comments
 (0)