     CreateModelStep,
     CacheConfig,
 )
+from sagemaker.pipeline import PipelineModel
+from sagemaker.sparkml import SparkMLModel
+from sagemaker.predictor import Predictor
+from sagemaker.model import FrameworkModel
 from tests.unit import DATA_DIR

 DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -89,6 +93,21 @@ def properties(self):
         return self._properties


+class DummyFrameworkModel(FrameworkModel):
+    def __init__(self, sagemaker_session, **kwargs):
+        super(DummyFrameworkModel, self).__init__(
+            "s3://bucket/model_1.tar.gz",
+            "mi-1",
+            ROLE,
+            os.path.join(DATA_DIR, "dummy_script.py"),
+            sagemaker_session=sagemaker_session,
+            **kwargs,
+        )
+
+    def create_predictor(self, endpoint_name):
+        return Predictor(endpoint_name, self.sagemaker_session)
+
+
 @pytest.fixture
 def boto_session():
     role_mock = Mock()
@@ -704,6 +723,63 @@ def test_create_model_step(sagemaker_session):
     assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}


+@patch("tarfile.open")
+@patch("time.strftime", return_value="2017-10-10-14-14-15")
+def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
+    framework_model = DummyFrameworkModel(sagemaker_session)
+    sparkml_model = SparkMLModel(
+        model_data="s3://bucket/model_2.tar.gz",
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
+    )
+    model = PipelineModel(
+        models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session
+    )
+    inputs = CreateModelInput(
+        instance_type="c4.4xlarge",
+        accelerator_type="ml.eia1.medium",
+    )
+    step = CreateModelStep(
+        name="MyCreateModelStep",
+        depends_on=["TestStep"],
+        display_name="MyCreateModelStep",
+        description="TestDescription",
+        model=model,
+        inputs=inputs,
+    )
+    step.add_depends_on(["SecondTestStep"])
+
+    assert step.to_request() == {
+        "Name": "MyCreateModelStep",
+        "Type": "Model",
+        "Description": "TestDescription",
+        "DisplayName": "MyCreateModelStep",
+        "DependsOn": ["TestStep", "SecondTestStep"],
+        "Arguments": {
+            "Containers": [
+                {
+                    "Environment": {
+                        "SAGEMAKER_PROGRAM": "dummy_script.py",
+                        "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
+                        "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
+                        "SAGEMAKER_REGION": "us-west-2",
+                    },
+                    "Image": "mi-1",
+                    "ModelDataUrl": "s3://bucket/model_1.tar.gz",
+                },
+                {
+                    "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
+                    "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
+                    "ModelDataUrl": "s3://bucket/model_2.tar.gz",
+                },
+            ],
+            "ExecutionRoleArn": "DummyRole",
+        },
+    }
+    assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"}
+
+
 def test_transform_step(sagemaker_session):
     transformer = Transformer(
         model_name=MODEL_NAME,
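
Usage note (not part of the diff): a minimal sketch of how the CreateModelStep built from the PipelineModel above could be attached to a pipeline. The pipeline name below is a hypothetical placeholder; Pipeline and its definition()/upsert() methods come from sagemaker.workflow.pipeline in the same SDK.

from sagemaker.workflow.pipeline import Pipeline

# Hypothetical wiring of the step constructed in the test into a pipeline.
pipeline = Pipeline(
    name="MyCreateModelPipeline",  # placeholder name, not taken from the diff
    steps=[step],                  # the CreateModelStep built above
    sagemaker_session=sagemaker_session,
)
definition = pipeline.definition()  # JSON definition; each step's request dict appears under "Steps"
# pipeline.upsert(role_arn=ROLE)    # would create or update the pipeline in SageMaker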