|
15 | 15 | import pytest
|
16 | 16 | from mock import MagicMock, Mock, patch
|
17 | 17 |
|
18 |
| -from sagemaker.transformer import Transformer, _TransformJob |
| 18 | +from sagemaker.transformer import _TransformJob, Transformer |
19 | 19 | from tests.integ import test_local_mode
|
20 | 20 |
|
21 | 21 | MODEL_NAME = 'model'
|
|
40 | 40 | 'base_transform_job_name': JOB_NAME
|
41 | 41 | }
|
42 | 42 |
|
| 43 | +MODEL_DESC_PRIMARY_CONTAINER = { |
| 44 | + 'PrimaryContainer': { |
| 45 | + 'Image': IMAGE_NAME |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +MODEL_DESC_CONTAINERS_ONLY = { |
| 50 | + 'Containers': [ |
| 51 | + {'Image': IMAGE_NAME} |
| 52 | + ] |
| 53 | +} |
| 54 | + |
43 | 55 |
|
44 | 56 | @pytest.fixture(autouse=True)
|
45 | 57 | def mock_create_tar_file():
|
@@ -97,29 +109,80 @@ def test_transform_with_all_params(start_new_job, transformer):
|
97 | 109 |
|
98 | 110 | @patch('sagemaker.transformer.name_from_base')
|
99 | 111 | @patch('sagemaker.transformer._TransformJob.start_new')
|
100 |
| -def test_transform_with_base_job_name(start_new_job, name_from_base, transformer): |
| 112 | +def test_transform_with_base_job_name_provided(start_new_job, name_from_base, transformer): |
101 | 113 | base_name = 'base-job-name'
|
102 | 114 | full_name = '{}-{}'.format(base_name, TIMESTAMP)
|
103 | 115 |
|
104 | 116 | transformer.base_transform_job_name = base_name
|
105 | 117 | name_from_base.return_value = full_name
|
106 | 118 |
|
107 | 119 | transformer.transform(DATA)
|
108 |
| - assert name_from_base.called_with(base_name) |
| 120 | + |
| 121 | + name_from_base.assert_called_once_with(base_name) |
| 122 | + assert transformer._current_job_name == full_name |
| 123 | + |
| 124 | + |
| 125 | +@patch('sagemaker.transformer.Transformer._retrieve_base_name', return_value=IMAGE_NAME) |
| 126 | +@patch('sagemaker.transformer.name_from_base') |
| 127 | +@patch('sagemaker.transformer._TransformJob.start_new') |
| 128 | +def test_transform_with_base_name(start_new_job, name_from_base, retrieve_base_name, transformer): |
| 129 | + full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) |
| 130 | + name_from_base.return_value = full_name |
| 131 | + |
| 132 | + transformer.transform(DATA) |
| 133 | + |
| 134 | + retrieve_base_name.assert_called_once_with() |
| 135 | + name_from_base.assert_called_once_with(IMAGE_NAME) |
109 | 136 | assert transformer._current_job_name == full_name
|
110 | 137 |
|
111 | 138 |
|
112 | 139 | @patch('sagemaker.transformer.Transformer._retrieve_image_name', return_value=IMAGE_NAME)
|
113 | 140 | @patch('sagemaker.transformer.name_from_base')
|
114 | 141 | @patch('sagemaker.transformer._TransformJob.start_new')
|
115 |
| -def test_transform_with_fully_generated_job_name(start_new_job, name_from_base, retrieve_image_name, transformer): |
| 142 | +def test_transform_with_job_name_based_on_image(start_new_job, name_from_base, retrieve_image_name, transformer): |
116 | 143 | full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
|
117 | 144 | name_from_base.return_value = full_name
|
118 | 145 |
|
119 | 146 | transformer.transform(DATA)
|
120 | 147 |
|
121 |
| - assert retrieve_image_name.called_once |
122 |
| - assert name_from_base.called_with(IMAGE_NAME) |
| 148 | + retrieve_image_name.assert_called_once_with() |
| 149 | + name_from_base.assert_called_once_with(IMAGE_NAME) |
| 150 | + assert transformer._current_job_name == full_name |
| 151 | + |
| 152 | + |
| 153 | +@pytest.mark.parametrize('model_desc', [MODEL_DESC_PRIMARY_CONTAINER, |
| 154 | + MODEL_DESC_CONTAINERS_ONLY]) |
| 155 | +@patch('sagemaker.transformer.name_from_base') |
| 156 | +@patch('sagemaker.transformer._TransformJob.start_new') |
| 157 | +def test_transform_with_job_name_based_on_containers(start_new_job, name_from_base, model_desc, transformer): |
| 158 | + transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc |
| 159 | + |
| 160 | + full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP) |
| 161 | + name_from_base.return_value = full_name |
| 162 | + |
| 163 | + transformer.transform(DATA) |
| 164 | + |
| 165 | + transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME) |
| 166 | + name_from_base.assert_called_once_with(IMAGE_NAME) |
| 167 | + assert transformer._current_job_name == full_name |
| 168 | + |
| 169 | + |
| 170 | +@pytest.mark.parametrize('model_desc', [{'PrimaryContainer': dict()}, |
| 171 | + {'Containers': [dict()]}, |
| 172 | + dict(), |
| 173 | + ]) |
| 174 | +@patch('sagemaker.transformer.name_from_base') |
| 175 | +@patch('sagemaker.transformer._TransformJob.start_new') |
| 176 | +def test_transform_with_job_name_based_on_model_name(start_new_job, name_from_base, model_desc, transformer): |
| 177 | + transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc |
| 178 | + |
| 179 | + full_name = '{}-{}'.format(MODEL_NAME, TIMESTAMP) |
| 180 | + name_from_base.return_value = full_name |
| 181 | + |
| 182 | + transformer.transform(DATA) |
| 183 | + |
| 184 | + transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME) |
| 185 | + name_from_base.assert_called_once_with(MODEL_NAME) |
123 | 186 | assert transformer._current_job_name == full_name
|
124 | 187 |
|
125 | 188 |
|
|
0 commit comments