|
20 | 20 | from botocore.exceptions import WaiterError
|
21 | 21 |
|
22 | 22 | import tests
|
| 23 | +from sagemaker.tensorflow import TensorFlow, TensorFlowModel |
23 | 24 | from tests.integ.retry import retries
|
24 | 25 | from sagemaker.drift_check_baselines import DriftCheckBaselines
|
25 | 26 | from sagemaker import (
|
@@ -745,3 +746,101 @@ def test_model_registration_with_model_repack(
|
745 | 746 | pipeline.delete()
|
746 | 747 | except Exception:
|
747 | 748 | pass
|
| 749 | + |
| 750 | + |
| 751 | +def test_model_registration_with_tensorflow_model_with_pipeline_model( |
| 752 | + sagemaker_session, role, tf_full_version, tf_full_py_version, pipeline_name, region_name |
| 753 | +): |
| 754 | + base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") |
| 755 | + entry_point = os.path.join(base_dir, "mnist_v2.py") |
| 756 | + input_path = sagemaker_session.upload_data( |
| 757 | + path=os.path.join(base_dir, "data"), |
| 758 | + key_prefix="integ-test-data/tf-scriptmode/mnist/training", |
| 759 | + ) |
| 760 | + inputs = TrainingInput(s3_data=input_path) |
| 761 | + |
| 762 | + instance_count = ParameterInteger(name="InstanceCount", default_value=1) |
| 763 | + instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") |
| 764 | + |
| 765 | + tensorflow_estimator = TensorFlow( |
| 766 | + entry_point=entry_point, |
| 767 | + role=role, |
| 768 | + instance_count=instance_count, |
| 769 | + instance_type=instance_type, |
| 770 | + framework_version=tf_full_version, |
| 771 | + py_version=tf_full_py_version, |
| 772 | + sagemaker_session=sagemaker_session, |
| 773 | + ) |
| 774 | + step_train = TrainingStep( |
| 775 | + name="MyTrain", |
| 776 | + estimator=tensorflow_estimator, |
| 777 | + inputs=inputs, |
| 778 | + ) |
| 779 | + |
| 780 | + model = TensorFlowModel( |
| 781 | + entry_point=entry_point, |
| 782 | + framework_version="2.4", |
| 783 | + model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
| 784 | + role=role, |
| 785 | + sagemaker_session=sagemaker_session, |
| 786 | + ) |
| 787 | + |
| 788 | + pipeline_model = PipelineModel( |
| 789 | + name="MyModelPipeline", models=[model], role=role, sagemaker_session=sagemaker_session |
| 790 | + ) |
| 791 | + |
| 792 | + step_register_model = RegisterModel( |
| 793 | + name="MyRegisterModel", |
| 794 | + model=pipeline_model, |
| 795 | + model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
| 796 | + content_types=["application/json"], |
| 797 | + response_types=["application/json"], |
| 798 | + inference_instances=["ml.t2.medium", "ml.m5.large"], |
| 799 | + transform_instances=["ml.m5.large"], |
| 800 | + model_package_group_name=f"{pipeline_name}TestModelPackageGroup", |
| 801 | + ) |
| 802 | + |
| 803 | + pipeline = Pipeline( |
| 804 | + name=pipeline_name, |
| 805 | + parameters=[ |
| 806 | + instance_count, |
| 807 | + instance_type, |
| 808 | + ], |
| 809 | + steps=[step_train, step_register_model], |
| 810 | + sagemaker_session=sagemaker_session, |
| 811 | + ) |
| 812 | + |
| 813 | + try: |
| 814 | + response = pipeline.create(role) |
| 815 | + create_arn = response["PipelineArn"] |
| 816 | + |
| 817 | + assert re.match( |
| 818 | + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", |
| 819 | + create_arn, |
| 820 | + ) |
| 821 | + |
| 822 | + for _ in retries( |
| 823 | + max_retry_count=5, |
| 824 | + exception_message_prefix="Waiting for a successful execution of pipeline", |
| 825 | + seconds_to_sleep=10, |
| 826 | + ): |
| 827 | + execution = pipeline.start(parameters={}) |
| 828 | + assert re.match( |
| 829 | + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", |
| 830 | + execution.arn, |
| 831 | + ) |
| 832 | + try: |
| 833 | + execution.wait(delay=30, max_attempts=60) |
| 834 | + except WaiterError: |
| 835 | + pass |
| 836 | + execution_steps = execution.list_steps() |
| 837 | + |
| 838 | + assert len(execution_steps) == 3 |
| 839 | + for step in execution_steps: |
| 840 | + assert step["StepStatus"] == "Succeeded" |
| 841 | + break |
| 842 | + finally: |
| 843 | + try: |
| 844 | + pipeline.delete() |
| 845 | + except Exception: |
| 846 | + pass |
0 commit comments