|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 |
| -import copy |
16 | 15 | import os
|
17 | 16 | import subprocess
|
18 | 17 |
|
19 |
| -import sagemaker |
20 |
| -from sagemaker.model import FrameworkModel, Model, ModelPackage |
| 18 | +from sagemaker.model import FrameworkModel |
21 | 19 | from sagemaker.predictor import RealTimePredictor
|
22 | 20 |
|
23 | 21 | import pytest
|
|
53 | 51 | CODECOMMIT_BRANCH = "master"
|
54 | 52 | REPO_DIR = "/tmp/repo_dir"
|
55 | 53 |
|
56 |
| - |
57 |
| -DESCRIBE_MODEL_PACKAGE_RESPONSE = { |
58 |
| - "InferenceSpecification": { |
59 |
| - "SupportedResponseMIMETypes": ["text"], |
60 |
| - "SupportedContentTypes": ["text/csv"], |
61 |
| - "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], |
62 |
| - "Containers": [ |
63 |
| - { |
64 |
| - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", |
65 |
| - "ImageDigest": "sha256:1234556789", |
66 |
| - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
67 |
| - } |
68 |
| - ], |
69 |
| - "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], |
70 |
| - }, |
71 |
| - "ModelPackageDescription": "Model Package created from training with " |
72 |
| - "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", |
73 |
| - "CreationTime": 1542752036.687, |
74 |
| - "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", |
75 |
| - "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, |
76 |
| - "SourceAlgorithmSpecification": { |
77 |
| - "SourceAlgorithms": [ |
78 |
| - { |
79 |
| - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
80 |
| - "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", |
81 |
| - } |
82 |
| - ] |
83 |
| - }, |
84 |
| - "ModelPackageStatus": "Completed", |
85 |
| - "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", |
86 |
| - "CertifyForMarketplace": False, |
87 |
| -} |
88 |
| - |
89 | 54 | DESCRIBE_COMPILATION_JOB_RESPONSE = {
|
90 | 55 | "CompilationJobStatus": "Completed",
|
91 | 56 | "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
|
@@ -417,181 +382,6 @@ def test_model_enable_network_isolation(sagemaker_session):
|
417 | 382 | assert model.enable_network_isolation() is False
|
418 | 383 |
|
419 | 384 |
|
420 |
| -@patch("sagemaker.model.Model._create_sagemaker_model") |
421 |
| -def test_model_create_transformer(create_sagemaker_model, sagemaker_session): |
422 |
| - model_name = "auto-generated-model" |
423 |
| - model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) |
424 |
| - |
425 |
| - instance_type = "ml.m4.xlarge" |
426 |
| - transformer = model.transformer(instance_count=1, instance_type=instance_type) |
427 |
| - |
428 |
| - create_sagemaker_model.assert_called_with(instance_type, tags=None) |
429 |
| - |
430 |
| - assert isinstance(transformer, sagemaker.transformer.Transformer) |
431 |
| - assert transformer.model_name == model_name |
432 |
| - assert transformer.instance_type == instance_type |
433 |
| - assert transformer.instance_count == 1 |
434 |
| - assert transformer.sagemaker_session == sagemaker_session |
435 |
| - assert transformer.base_transform_job_name == model_name |
436 |
| - |
437 |
| - assert transformer.strategy is None |
438 |
| - assert transformer.env is None |
439 |
| - assert transformer.output_path is None |
440 |
| - assert transformer.output_kms_key is None |
441 |
| - assert transformer.accept is None |
442 |
| - assert transformer.assemble_with is None |
443 |
| - assert transformer.volume_kms_key is None |
444 |
| - assert transformer.max_concurrent_transforms is None |
445 |
| - assert transformer.max_payload is None |
446 |
| - assert transformer.tags is None |
447 |
| - |
448 |
| - |
449 |
| -@patch("sagemaker.model.Model._create_sagemaker_model") |
450 |
| -def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): |
451 |
| - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) |
452 |
| - |
453 |
| - instance_type = "ml.m4.xlarge" |
454 |
| - strategy = "MultiRecord" |
455 |
| - assemble_with = "Line" |
456 |
| - output_path = "s3://bucket/path" |
457 |
| - kms_key = "key" |
458 |
| - accept = "text/csv" |
459 |
| - env = {"test": True} |
460 |
| - max_concurrent_transforms = 1 |
461 |
| - max_payload = 6 |
462 |
| - tags = [{"Key": "k", "Value": "v"}] |
463 |
| - |
464 |
| - transformer = model.transformer( |
465 |
| - instance_count=1, |
466 |
| - instance_type=instance_type, |
467 |
| - strategy=strategy, |
468 |
| - assemble_with=assemble_with, |
469 |
| - output_path=output_path, |
470 |
| - output_kms_key=kms_key, |
471 |
| - accept=accept, |
472 |
| - env=env, |
473 |
| - max_concurrent_transforms=max_concurrent_transforms, |
474 |
| - max_payload=max_payload, |
475 |
| - tags=tags, |
476 |
| - volume_kms_key=kms_key, |
477 |
| - ) |
478 |
| - |
479 |
| - create_sagemaker_model.assert_called_with(instance_type, tags=tags) |
480 |
| - |
481 |
| - assert isinstance(transformer, sagemaker.transformer.Transformer) |
482 |
| - assert transformer.strategy == strategy |
483 |
| - assert transformer.assemble_with == assemble_with |
484 |
| - assert transformer.output_path == output_path |
485 |
| - assert transformer.output_kms_key == kms_key |
486 |
| - assert transformer.accept == accept |
487 |
| - assert transformer.max_concurrent_transforms == max_concurrent_transforms |
488 |
| - assert transformer.max_payload == max_payload |
489 |
| - assert transformer.env == env |
490 |
| - assert transformer.tags == tags |
491 |
| - assert transformer.volume_kms_key == kms_key |
492 |
| - |
493 |
| - |
494 |
| -@patch("sagemaker.model.Model._create_sagemaker_model") |
495 |
| -def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): |
496 |
| - model = Model( |
497 |
| - MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True |
498 |
| - ) |
499 |
| - |
500 |
| - transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) |
501 |
| - assert transformer.env is None |
502 |
| - |
503 |
| - |
504 |
| -@patch("sagemaker.session.Session") |
505 |
| -@patch("sagemaker.local.LocalSession") |
506 |
| -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
507 |
| -def test_transformer_creates_correct_session(local_session, session): |
508 |
| - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
509 |
| - transformer = model.transformer(instance_count=1, instance_type="local") |
510 |
| - assert model.sagemaker_session == local_session.return_value |
511 |
| - assert transformer.sagemaker_session == local_session.return_value |
512 |
| - |
513 |
| - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
514 |
| - transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") |
515 |
| - assert model.sagemaker_session == session.return_value |
516 |
| - assert transformer.sagemaker_session == session.return_value |
517 |
| - |
518 |
| - |
519 |
| -def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): |
520 |
| - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
521 |
| - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE |
522 |
| - ) |
523 |
| - |
524 |
| - model_package = ModelPackage( |
525 |
| - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
526 |
| - ) |
527 |
| - assert model_package.enable_network_isolation() is False |
528 |
| - |
529 |
| - |
530 |
| -def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): |
531 |
| - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) |
532 |
| - model_package_response["InferenceSpecification"]["Containers"].append( |
533 |
| - { |
534 |
| - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", |
535 |
| - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
536 |
| - "ProductId": "some-product-id", |
537 |
| - } |
538 |
| - ) |
539 |
| - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
540 |
| - return_value=model_package_response |
541 |
| - ) |
542 |
| - |
543 |
| - model_package = ModelPackage( |
544 |
| - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
545 |
| - ) |
546 |
| - assert model_package.enable_network_isolation() is True |
547 |
| - |
548 |
| - |
549 |
| -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) |
550 |
| -def test_model_package_create_transformer(sagemaker_session): |
551 |
| - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
552 |
| - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE |
553 |
| - ) |
554 |
| - |
555 |
| - model_package = ModelPackage( |
556 |
| - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
557 |
| - ) |
558 |
| - model_package.name = "auto-generated-model" |
559 |
| - transformer = model_package.transformer( |
560 |
| - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} |
561 |
| - ) |
562 |
| - assert isinstance(transformer, sagemaker.transformer.Transformer) |
563 |
| - assert transformer.model_name == "auto-generated-model" |
564 |
| - assert transformer.instance_type == "ml.m4.xlarge" |
565 |
| - assert transformer.env == {"test": True} |
566 |
| - |
567 |
| - |
568 |
| -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) |
569 |
| -def test_model_package_create_transformer_with_product_id(sagemaker_session): |
570 |
| - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) |
571 |
| - model_package_response["InferenceSpecification"]["Containers"].append( |
572 |
| - { |
573 |
| - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", |
574 |
| - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
575 |
| - "ProductId": "some-product-id", |
576 |
| - } |
577 |
| - ) |
578 |
| - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
579 |
| - return_value=model_package_response |
580 |
| - ) |
581 |
| - |
582 |
| - model_package = ModelPackage( |
583 |
| - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
584 |
| - ) |
585 |
| - model_package.name = "auto-generated-model" |
586 |
| - transformer = model_package.transformer( |
587 |
| - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} |
588 |
| - ) |
589 |
| - assert isinstance(transformer, sagemaker.transformer.Transformer) |
590 |
| - assert transformer.model_name == "auto-generated-model" |
591 |
| - assert transformer.instance_type == "ml.m4.xlarge" |
592 |
| - assert transformer.env is None |
593 |
| - |
594 |
| - |
595 | 385 | @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
|
596 | 386 | @patch("time.strftime", MagicMock(return_value=TIMESTAMP))
|
597 | 387 | def test_model_delete_model(sagemaker_session, tmpdir):
|
|
0 commit comments