Session.create_model(): use an existing model instead of failing.

Note: this patch was re-flowed from a whitespace-collapsed copy. Line breaks follow the
hunk headers' line counts; leading indentation was reconstructed, so a context line may
need `git apply --ignore-whitespace` to match. The "index" lines still carry the blob
ids of the original commit.

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 625381a942..24b61477d1 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -2,6 +2,11 @@
 CHANGELOG
 =========
 
+1.7.1dev
+========
+
+* bug-fix: Session: use existing model instead of failing during ``create_model()``
+
 1.7.0
 =====
 
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index d08c4fc307..20b802188f 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -439,15 +439,26 @@ def create_model(self, name, role, primary_container):
         role = self.expand_role(role)
         primary_container = _expand_container_def(primary_container)
         LOGGER.info('Creating model with name: {}'.format(name))
-        LOGGER.debug("create_model request: {}".format({
+        LOGGER.debug('create_model request: {}'.format({
             'name': name,
             'role': role,
             'primary_container': primary_container
         }))
 
-        self.sagemaker_client.create_model(ModelName=name,
-                                           PrimaryContainer=primary_container,
-                                           ExecutionRoleArn=role)
+        try:
+            self.sagemaker_client.create_model(ModelName=name,
+                                               PrimaryContainer=primary_container,
+                                               ExecutionRoleArn=role)
+        except ClientError as e:
+            error_code = e.response['Error']['Code']
+            message = e.response['Error']['Message']
+
+            # This code/message pair is treated as "model already exists": warn and fall
+            # through to return the name. Every other ClientError is re-raised unchanged.
+            if error_code == 'ValidationException' and 'Cannot create already existing model' in message:
+                LOGGER.warning('Using already existing model: {}'.format(name))
+            else:
+                raise
 
         return name
 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 00a6fd82be..e4373ab4dd 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -12,16 +12,17 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import pytest
+import datetime
 import io
+import logging
+
+import pytest
 import six
+from botocore.exceptions import ClientError
 from mock import Mock, patch, call
+
 import sagemaker
 from sagemaker import s3_input, Session, get_execution_role
-import datetime
-
-from botocore.exceptions import ClientError
-
 from sagemaker.session import _tuning_job_status, _transform_job_status
 
 REGION = 'us-west-2'
@@ -502,18 +503,58 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
                                    call(0, 'hi there #2a'), call(0, 'hi there #3')]
 
 
+MODEL_NAME = 'some-model'
+PRIMARY_CONTAINER = {
+    'Environment': {},
+    'Image': IMAGE,
+    'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
+}
+
+
+@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+def test_create_model(expand_container_def, sagemaker_session):
+    model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
+
+    assert model == MODEL_NAME
+    sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
+                                                                       ModelName=MODEL_NAME,
+                                                                       PrimaryContainer=PRIMARY_CONTAINER)
+
+
+@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog):
+    error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}}
+    exception = ClientError(error_response, 'Operation')
+    sagemaker_session.sagemaker_client.create_model.side_effect = exception
+
+    model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
+    assert model == MODEL_NAME
+
+    expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME))
+    assert expected_warning in caplog.record_tuples
+
+
+@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
+def test_create_model_failure(expand_container_def, sagemaker_session):
+    error_message = 'this is expected'
+    sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message)
+
+    with pytest.raises(RuntimeError) as e:
+        sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
+
+    # e is pytest's ExceptionInfo wrapper; the raised exception (and its message) is e.value.
+    assert error_message in str(e.value)
+
+
 def test_create_model_from_job(sagemaker_session):
     ims = sagemaker_session
     ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
     ims.create_model_from_job(JOB_NAME)
 
-    assert call(TrainingJobName='jobname') in ims.sagemaker_client.describe_training_job.call_args_list
-    ims.sagemaker_client.create_model.assert_called_with(
-        ExecutionRoleArn='arn:aws:iam::111111111111:role/ExpandedRole',
-        ModelName='jobname',
-        PrimaryContainer={
-            'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
-            'Image': 'myimage'})
+    assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
+    ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
+                                                         ModelName=JOB_NAME,
+                                                         PrimaryContainer=PRIMARY_CONTAINER)
 
 
 def test_create_model_from_job_with_image(sagemaker_session):