Skip to content

Commit cb4aa9d

Browse files
authored
Use existing model instead of failing during Session.create_model() (#306)
Currently an exception is raised if a model already exists. This change instead logs a warning that an existing model is being used, but allows the existing model to be used in place of needing to create a new one.
1 parent 704cd31 commit cb4aa9d

File tree

3 files changed

+70
-16
lines changed

3 files changed

+70
-16
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.7.1dev
6+
========
7+
8+
* bug-fix: Session: use existing model instead of failing during ``create_model()``
9+
510
1.7.0
611
=====
712

src/sagemaker/session.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,15 +439,24 @@ def create_model(self, name, role, primary_container):
439439
role = self.expand_role(role)
440440
primary_container = _expand_container_def(primary_container)
441441
LOGGER.info('Creating model with name: {}'.format(name))
442-
LOGGER.debug("create_model request: {}".format({
442+
LOGGER.debug('create_model request: {}'.format({
443443
'name': name,
444444
'role': role,
445445
'primary_container': primary_container
446446
}))
447447

448-
self.sagemaker_client.create_model(ModelName=name,
449-
PrimaryContainer=primary_container,
450-
ExecutionRoleArn=role)
448+
try:
449+
self.sagemaker_client.create_model(ModelName=name,
450+
PrimaryContainer=primary_container,
451+
ExecutionRoleArn=role)
452+
except ClientError as e:
453+
error_code = e.response['Error']['Code']
454+
message = e.response['Error']['Message']
455+
456+
if error_code == 'ValidationException' and 'Cannot create already existing model' in message:
457+
LOGGER.warning('Using already existing model: {}'.format(name))
458+
else:
459+
raise
451460

452461
return name
453462

tests/unit/test_session.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pytest
15+
import datetime
1616
import io
17+
import logging
18+
19+
import pytest
1720
import six
21+
from botocore.exceptions import ClientError
1822
from mock import Mock, patch, call
23+
1924
import sagemaker
2025
from sagemaker import s3_input, Session, get_execution_role
21-
import datetime
22-
23-
from botocore.exceptions import ClientError
24-
2526
from sagemaker.session import _tuning_job_status, _transform_job_status
2627

2728
REGION = 'us-west-2'
@@ -502,18 +503,57 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
502503
call(0, 'hi there #2a'), call(0, 'hi there #3')]
503504

504505

506+
MODEL_NAME = 'some-model'
507+
PRIMARY_CONTAINER = {
508+
'Environment': {},
509+
'Image': IMAGE,
510+
'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
511+
}
512+
513+
514+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
515+
def test_create_model(expand_container_def, sagemaker_session):
516+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
517+
518+
assert model == MODEL_NAME
519+
sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
520+
ModelName=MODEL_NAME,
521+
PrimaryContainer=PRIMARY_CONTAINER)
522+
523+
524+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
525+
def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog):
526+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}}
527+
exception = ClientError(error_response, 'Operation')
528+
sagemaker_session.sagemaker_client.create_model.side_effect = exception
529+
530+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
531+
assert model == MODEL_NAME
532+
533+
expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME))
534+
assert expected_warning in caplog.record_tuples
535+
536+
537+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
538+
def test_create_model_failure(expand_container_def, sagemaker_session):
539+
error_message = 'this is expected'
540+
sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message)
541+
542+
with pytest.raises(RuntimeError) as e:
543+
sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
544+
545+
assert error_message in str(e)
546+
547+
505548
def test_create_model_from_job(sagemaker_session):
506549
ims = sagemaker_session
507550
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
508551
ims.create_model_from_job(JOB_NAME)
509552

510-
assert call(TrainingJobName='jobname') in ims.sagemaker_client.describe_training_job.call_args_list
511-
ims.sagemaker_client.create_model.assert_called_with(
512-
ExecutionRoleArn='arn:aws:iam::111111111111:role/ExpandedRole',
513-
ModelName='jobname',
514-
PrimaryContainer={
515-
'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
516-
'Image': 'myimage'})
553+
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
554+
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
555+
ModelName=JOB_NAME,
556+
PrimaryContainer=PRIMARY_CONTAINER)
517557

518558

519559
def test_create_model_from_job_with_image(sagemaker_session):

0 commit comments

Comments
 (0)