|
19 | 19 | import sagemaker
|
20 | 20 | from sagemaker.model import FrameworkModel, Model
|
21 | 21 | from sagemaker.huggingface.model import HuggingFaceModel
|
22 |
| -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET |
| 22 | +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME |
23 | 23 | from sagemaker.jumpstart.enums import JumpStartTag
|
24 | 24 | from sagemaker.mxnet.model import MXNetModel
|
25 | 25 | from sagemaker.pytorch.model import PyTorchModel
|
@@ -569,3 +569,93 @@ def test_all_framework_models_add_jumpstart_tags(
|
569 | 569 |
|
570 | 570 | sagemaker_session.create_model.reset_mock()
|
571 | 571 | sagemaker_session.endpoint_from_production_variants.reset_mock()
|
| 572 | + |
| 573 | + |
| 574 | +@patch("sagemaker.utils.repack_model") |
| 575 | +def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_session): |
| 576 | + |
| 577 | + jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" |
| 578 | + t = Model( |
| 579 | + entry_point=ENTRY_POINT_INFERENCE, |
| 580 | + role=ROLE, |
| 581 | + sagemaker_session=sagemaker_session, |
| 582 | + source_dir=jumpstart_source_dir, |
| 583 | + image_uri=IMAGE_URI, |
| 584 | + model_data=MODEL_DATA, |
| 585 | + ) |
| 586 | + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) |
| 587 | + |
| 588 | + assert sagemaker_session.create_model.call_args_list[0][0][0].startswith( |
| 589 | + JUMPSTART_RESOURCE_BASE_NAME |
| 590 | + ) |
| 591 | + |
| 592 | + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith( |
| 593 | + JUMPSTART_RESOURCE_BASE_NAME |
| 594 | + ) |
| 595 | + |
| 596 | + sagemaker_session.create_model.reset_mock() |
| 597 | + sagemaker_session.endpoint_from_production_variants.reset_mock() |
| 598 | + |
| 599 | + non_jumpstart_source_dir = "s3://blah/blah/blah" |
| 600 | + t = Model( |
| 601 | + entry_point=ENTRY_POINT_INFERENCE, |
| 602 | + role=ROLE, |
| 603 | + sagemaker_session=sagemaker_session, |
| 604 | + source_dir=non_jumpstart_source_dir, |
| 605 | + image_uri=IMAGE_URI, |
| 606 | + model_data=MODEL_DATA, |
| 607 | + ) |
| 608 | + t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) |
| 609 | + |
| 610 | + assert not sagemaker_session.create_model.call_args_list[0][0][0].startswith( |
| 611 | + JUMPSTART_RESOURCE_BASE_NAME |
| 612 | + ) |
| 613 | + |
| 614 | + assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][ |
| 615 | + "name" |
| 616 | + ].startswith(JUMPSTART_RESOURCE_BASE_NAME) |
| 617 | + |
| 618 | + |
| 619 | +@patch("sagemaker.utils.repack_model") |
| 620 | +@patch("sagemaker.fw_utils.tar_and_upload_dir") |
| 621 | +def test_all_framework_models_add_jumpstart_base_name( |
| 622 | + repack_model, tar_and_uload_dir, sagemaker_session |
| 623 | +): |
| 624 | + framework_model_classes_to_kwargs = { |
| 625 | + PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"}, |
| 626 | + TensorFlowModel: { |
| 627 | + "framework_version": "2.3", |
| 628 | + }, |
| 629 | + HuggingFaceModel: { |
| 630 | + "pytorch_version": "1.7.1", |
| 631 | + "py_version": "py36", |
| 632 | + "transformers_version": "4.6.1", |
| 633 | + }, |
| 634 | + MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"}, |
| 635 | + SKLearnModel: { |
| 636 | + "framework_version": "0.23-1", |
| 637 | + }, |
| 638 | + XGBoostModel: { |
| 639 | + "framework_version": "1.3-1", |
| 640 | + }, |
| 641 | + } |
| 642 | + jumpstart_model_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" |
| 643 | + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): |
| 644 | + framework_model_class( |
| 645 | + entry_point=ENTRY_POINT_INFERENCE, |
| 646 | + role=ROLE, |
| 647 | + sagemaker_session=sagemaker_session, |
| 648 | + model_data=jumpstart_model_dir, |
| 649 | + **kwargs, |
| 650 | + ).deploy(instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT) |
| 651 | + |
| 652 | + assert sagemaker_session.create_model.call_args_list[0][0][0].startswith( |
| 653 | + JUMPSTART_RESOURCE_BASE_NAME |
| 654 | + ) |
| 655 | + |
| 656 | + assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith( |
| 657 | + JUMPSTART_RESOURCE_BASE_NAME |
| 658 | + ) |
| 659 | + |
| 660 | + sagemaker_session.create_model.reset_mock() |
| 661 | + sagemaker_session.endpoint_from_production_variants.reset_mock() |
0 commit comments