|
27 | 27 | BAD_RESPONSE = urllib3.HTTPResponse()
|
28 | 28 | BAD_RESPONSE.status = 502
|
29 | 29 |
|
| 30 | +ENDPOINT_CONFIG_NAME = 'test-endpoint-config' |
| 31 | +PRODUCTION_VARIANTS = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}] |
| 32 | + |
| 33 | +MODEL_NAME = 'test-model' |
| 34 | +PRIMARY_CONTAINER = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}} |
| 35 | + |
30 | 36 |
|
31 | 37 | @patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
|
32 | 38 | @patch('sagemaker.local.local_session.LocalSession')
|
@@ -148,25 +154,32 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
|
148 | 154 | @patch('sagemaker.local.local_session.LocalSession')
|
149 | 155 | def test_create_model(LocalSession):
|
150 | 156 | local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
|
151 |
| - model_name = 'my-model' |
152 |
| - primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}} |
153 | 157 |
|
154 |
| - local_sagemaker_client.create_model(model_name, primary_container) |
| 158 | + local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) |
| 159 | + |
| 160 | + assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models |
| 161 | + |
| 162 | + |
| 163 | +@patch('sagemaker.local.local_session.LocalSession') |
| 164 | +def test_delete_model(LocalSession): |
| 165 | + local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() |
| 166 | + |
| 167 | + local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) |
| 168 | + assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models |
155 | 169 |
|
156 |
| - assert 'my-model' in sagemaker.local.local_session.LocalSagemakerClient._models |
| 170 | + local_sagemaker_client.delete_model(MODEL_NAME) |
| 171 | + assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models |
157 | 172 |
|
158 | 173 |
|
159 | 174 | @patch('sagemaker.local.local_session.LocalSession')
|
160 | 175 | def test_describe_model(LocalSession):
|
161 | 176 | local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
|
162 |
| - model_name = 'test-model' |
163 |
| - primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}} |
164 | 177 |
|
165 | 178 | with pytest.raises(ClientError):
|
166 | 179 | local_sagemaker_client.describe_model('model-does-not-exist')
|
167 | 180 |
|
168 |
| - local_sagemaker_client.create_model(model_name, primary_container) |
169 |
| - response = local_sagemaker_client.describe_model('test-model') |
| 181 | + local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) |
| 182 | + response = local_sagemaker_client.describe_model(MODEL_NAME) |
170 | 183 |
|
171 | 184 | assert response['ModelName'] == 'test-model'
|
172 | 185 | assert response['PrimaryContainer']['ModelDataUrl'] == '/some/model/path'
|
@@ -212,10 +225,20 @@ def test_describe_endpoint_config(LocalSession):
|
212 | 225 | @patch('sagemaker.local.local_session.LocalSession')
|
213 | 226 | def test_create_endpoint_config(LocalSession):
|
214 | 227 | local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
|
215 |
| - production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}] |
216 |
| - local_sagemaker_client.create_endpoint_config('my-endpoint-config', production_variants) |
| 228 | + local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) |
| 229 | + |
| 230 | + assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs |
| 231 | + |
| 232 | + |
| 233 | +@patch('sagemaker.local.local_session.LocalSession') |
| 234 | +def test_delete_endpoint_config(LocalSession): |
| 235 | + local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() |
| 236 | + |
| 237 | + local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) |
| 238 | + assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs |
217 | 239 |
|
218 |
| - assert 'my-endpoint-config' in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs |
| 240 | + local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME) |
| 241 | + assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs |
219 | 242 |
|
220 | 243 |
|
221 | 244 | @patch('sagemaker.local.image._SageMakerContainer.serve')
|
@@ -316,7 +339,7 @@ def test_update_endpoint(LocalSession):
|
316 | 339 | endpoint_name = 'my-endpoint'
|
317 | 340 | endpoint_config = 'my-endpoint-config'
|
318 | 341 | expected_error_message = 'Update endpoint name is not supported in local session.'
|
319 |
| - with pytest.raises(NotImplementedError, message=expected_error_message): |
| 342 | + with pytest.raises(NotImplementedError, match=expected_error_message): |
320 | 343 | local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)
|
321 | 344 |
|
322 | 345 |
|
|
0 commit comments