|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 |
| -import pytest |
| 15 | +import datetime |
16 | 16 | import io
|
| 17 | +import logging |
| 18 | + |
| 19 | +import pytest |
17 | 20 | import six
|
| 21 | +from botocore.exceptions import ClientError |
18 | 22 | from mock import Mock, patch, call
|
| 23 | + |
19 | 24 | import sagemaker
|
20 | 25 | from sagemaker import s3_input, Session, get_execution_role
|
21 |
| -import datetime |
22 |
| - |
23 |
| -from botocore.exceptions import ClientError |
24 |
| - |
25 | 26 | from sagemaker.session import _tuning_job_status, _transform_job_status
|
26 | 27 |
|
27 | 28 | REGION = 'us-west-2'
|
@@ -515,18 +516,57 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
|
515 | 516 | call(0, 'hi there #2a'), call(0, 'hi there #3')]
|
516 | 517 |
|
517 | 518 |
|
| 519 | +MODEL_NAME = 'some-model' |
| 520 | +PRIMARY_CONTAINER = { |
| 521 | + 'Environment': {}, |
| 522 | + 'Image': IMAGE, |
| 523 | + 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz', |
| 524 | +} |
| 525 | + |
| 526 | + |
| 527 | +@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER) |
| 528 | +def test_create_model(expand_container_def, sagemaker_session): |
| 529 | + model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER) |
| 530 | + |
| 531 | + assert model == MODEL_NAME |
| 532 | + sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE, |
| 533 | + ModelName=MODEL_NAME, |
| 534 | + PrimaryContainer=PRIMARY_CONTAINER) |
| 535 | + |
| 536 | + |
| 537 | +@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER) |
| 538 | +def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog): |
| 539 | + error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}} |
| 540 | + exception = ClientError(error_response, 'Operation') |
| 541 | + sagemaker_session.sagemaker_client.create_model.side_effect = exception |
| 542 | + |
| 543 | + model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER) |
| 544 | + assert model == MODEL_NAME |
| 545 | + |
| 546 | + expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME)) |
| 547 | + assert expected_warning in caplog.record_tuples |
| 548 | + |
| 549 | + |
| 550 | +@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER) |
| 551 | +def test_create_model_failure(expand_container_def, sagemaker_session): |
| 552 | + error_message = 'this is expected' |
| 553 | + sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message) |
| 554 | + |
| 555 | + with pytest.raises(RuntimeError) as e: |
| 556 | + sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER) |
| 557 | + |
| 558 | + assert error_message in str(e) |
| 559 | + |
| 560 | + |
518 | 561 | def test_create_model_from_job(sagemaker_session):
|
519 | 562 | ims = sagemaker_session
|
520 | 563 | ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
|
521 | 564 | ims.create_model_from_job(JOB_NAME)
|
522 | 565 |
|
523 |
| - assert call(TrainingJobName='jobname') in ims.sagemaker_client.describe_training_job.call_args_list |
524 |
| - ims.sagemaker_client.create_model.assert_called_with( |
525 |
| - ExecutionRoleArn='arn:aws:iam::111111111111:role/ExpandedRole', |
526 |
| - ModelName='jobname', |
527 |
| - PrimaryContainer={ |
528 |
| - 'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz', |
529 |
| - 'Image': 'myimage'}) |
| 566 | + assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list |
| 567 | + ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE, |
| 568 | + ModelName=JOB_NAME, |
| 569 | + PrimaryContainer=PRIMARY_CONTAINER) |
530 | 570 |
|
531 | 571 |
|
532 | 572 | def test_create_model_from_job_with_image(sagemaker_session):
|
@@ -605,7 +645,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
|
605 | 645 | Tags=tags)
|
606 | 646 |
|
607 | 647 |
|
608 |
| -def test_wait_for_tuning_job(sagemaker_session): |
| 648 | +@patch('time.sleep') |
| 649 | +def test_wait_for_tuning_job(sleep, sagemaker_session): |
609 | 650 | hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
|
610 | 651 | sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
|
611 | 652 | name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
|
@@ -634,15 +675,17 @@ def test_tune_job_status_none(sagemaker_session):
|
634 | 675 | assert result is None
|
635 | 676 |
|
636 | 677 |
|
637 |
| -def test_wait_for_transform_job_completed(sagemaker_session): |
| 678 | +@patch('time.sleep') |
| 679 | +def test_wait_for_transform_job_completed(sleep, sagemaker_session): |
638 | 680 | transform_job_desc = {'TransformJobStatus': 'Completed'}
|
639 | 681 | sagemaker_session.sagemaker_client.describe_transform_job = Mock(
|
640 | 682 | name='describe_transform_job', return_value=transform_job_desc)
|
641 | 683 |
|
642 | 684 | assert sagemaker_session.wait_for_transform_job(JOB_NAME)['TransformJobStatus'] == 'Completed'
|
643 | 685 |
|
644 | 686 |
|
645 |
| -def test_wait_for_transform_job_in_progress(sagemaker_session): |
| 687 | +@patch('time.sleep') |
| 688 | +def test_wait_for_transform_job_in_progress(sleep, sagemaker_session): |
646 | 689 | transform_job_desc_in_progress = {'TransformJobStatus': 'InProgress'}
|
647 | 690 | transform_job_desc_in_completed = {'TransformJobStatus': 'Completed'}
|
648 | 691 | sagemaker_session.sagemaker_client.describe_transform_job = Mock(
|
|
0 commit comments