Skip to content

Commit 96da497

Browse files
authored
Merge branch 'master' into get-execution-role-fix
2 parents 7cd5c6f + 6ff1e23 commit 96da497

File tree

4 files changed

+77 -23 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
=====
77

88
* bug-fix: get_execution_role no longer fails if user can't call get_role
9+
* bug-fix: Session: use existing model instead of failing during ``create_model()``
910

1011
1.7.0
1112
=====

src/sagemaker/session.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -439,15 +439,24 @@ def create_model(self, name, role, primary_container):
439439
role = self.expand_role(role)
440440
primary_container = _expand_container_def(primary_container)
441441
LOGGER.info('Creating model with name: {}'.format(name))
442-
LOGGER.debug("create_model request: {}".format({
442+
LOGGER.debug('create_model request: {}'.format({
443443
'name': name,
444444
'role': role,
445445
'primary_container': primary_container
446446
}))
447447

448-
self.sagemaker_client.create_model(ModelName=name,
449-
PrimaryContainer=primary_container,
450-
ExecutionRoleArn=role)
448+
try:
449+
self.sagemaker_client.create_model(ModelName=name,
450+
PrimaryContainer=primary_container,
451+
ExecutionRoleArn=role)
452+
except ClientError as e:
453+
error_code = e.response['Error']['Code']
454+
message = e.response['Error']['Message']
455+
456+
if error_code == 'ValidationException' and 'Cannot create already existing model' in message:
457+
LOGGER.warning('Using already existing model: {}'.format(name))
458+
else:
459+
raise
451460

452461
return name
453462

tests/unit/test_session.py

Lines changed: 58 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -12,16 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pytest
15+
import datetime
1616
import io
17+
import logging
18+
19+
import pytest
1720
import six
21+
from botocore.exceptions import ClientError
1822
from mock import Mock, patch, call
23+
1924
import sagemaker
2025
from sagemaker import s3_input, Session, get_execution_role
21-
import datetime
22-
23-
from botocore.exceptions import ClientError
24-
2526
from sagemaker.session import _tuning_job_status, _transform_job_status
2627

2728
REGION = 'us-west-2'
@@ -515,18 +516,57 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
515516
call(0, 'hi there #2a'), call(0, 'hi there #3')]
516517

517518

519+
MODEL_NAME = 'some-model'
520+
PRIMARY_CONTAINER = {
521+
'Environment': {},
522+
'Image': IMAGE,
523+
'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
524+
}
525+
526+
527+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
528+
def test_create_model(expand_container_def, sagemaker_session):
529+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
530+
531+
assert model == MODEL_NAME
532+
sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
533+
ModelName=MODEL_NAME,
534+
PrimaryContainer=PRIMARY_CONTAINER)
535+
536+
537+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
538+
def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog):
539+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}}
540+
exception = ClientError(error_response, 'Operation')
541+
sagemaker_session.sagemaker_client.create_model.side_effect = exception
542+
543+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
544+
assert model == MODEL_NAME
545+
546+
expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME))
547+
assert expected_warning in caplog.record_tuples
548+
549+
550+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
551+
def test_create_model_failure(expand_container_def, sagemaker_session):
552+
error_message = 'this is expected'
553+
sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message)
554+
555+
with pytest.raises(RuntimeError) as e:
556+
sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
557+
558+
assert error_message in str(e)
559+
560+
518561
def test_create_model_from_job(sagemaker_session):
519562
ims = sagemaker_session
520563
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
521564
ims.create_model_from_job(JOB_NAME)
522565

523-
assert call(TrainingJobName='jobname') in ims.sagemaker_client.describe_training_job.call_args_list
524-
ims.sagemaker_client.create_model.assert_called_with(
525-
ExecutionRoleArn='arn:aws:iam::111111111111:role/ExpandedRole',
526-
ModelName='jobname',
527-
PrimaryContainer={
528-
'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
529-
'Image': 'myimage'})
566+
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
567+
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
568+
ModelName=JOB_NAME,
569+
PrimaryContainer=PRIMARY_CONTAINER)
530570

531571

532572
def test_create_model_from_job_with_image(sagemaker_session):
@@ -605,7 +645,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
605645
Tags=tags)
606646

607647

608-
def test_wait_for_tuning_job(sagemaker_session):
648+
@patch('time.sleep')
649+
def test_wait_for_tuning_job(sleep, sagemaker_session):
609650
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
610651
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
611652
name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
@@ -634,15 +675,17 @@ def test_tune_job_status_none(sagemaker_session):
634675
assert result is None
635676

636677

637-
def test_wait_for_transform_job_completed(sagemaker_session):
678+
@patch('time.sleep')
679+
def test_wait_for_transform_job_completed(sleep, sagemaker_session):
638680
transform_job_desc = {'TransformJobStatus': 'Completed'}
639681
sagemaker_session.sagemaker_client.describe_transform_job = Mock(
640682
name='describe_transform_job', return_value=transform_job_desc)
641683

642684
assert sagemaker_session.wait_for_transform_job(JOB_NAME)['TransformJobStatus'] == 'Completed'
643685

644686

645-
def test_wait_for_transform_job_in_progress(sagemaker_session):
687+
@patch('time.sleep')
688+
def test_wait_for_transform_job_in_progress(sleep, sagemaker_session):
646689
transform_job_desc_in_progress = {'TransformJobStatus': 'InProgress'}
647690
transform_job_desc_in_completed = {'TransformJobStatus': 'Completed'}
648691
sagemaker_session.sagemaker_client.describe_transform_job = Mock(

tests/unit/test_tf_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -312,7 +312,8 @@ def test_run_tensorboard_locally_without_awscli_binary(time, strftime, popen, ca
312312
@patch('subprocess.Popen')
313313
@patch('time.strftime', return_value=TIMESTAMP)
314314
@patch('time.time', return_value=TIME)
315-
def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session):
315+
@patch('time.sleep')
316+
def test_run_tensorboard_locally(sleep, time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session):
316317
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
317318
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
318319

@@ -322,8 +323,7 @@ def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mk
322323

323324
popen.assert_called_with(['tensorboard', '--logdir', '/my/temp/folder', '--host', 'localhost', '--port', '6006'],
324325
stderr=-1,
325-
stdout=-1
326-
)
326+
stdout=-1)
327327

328328

329329
@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories')
@@ -335,7 +335,8 @@ def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mk
335335
@patch('subprocess.Popen')
336336
@patch('time.strftime', return_value=TIMESTAMP)
337337
@patch('time.time', return_value=TIME)
338-
def test_run_tensorboard_locally_port_in_use(time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync,
338+
@patch('time.sleep')
339+
def test_run_tensorboard_locally_port_in_use(sleep, time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync,
339340
sagemaker_session):
340341
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
341342
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)

0 commit comments

Comments
 (0)