|
20 | 20 | from botocore.exceptions import WaiterError
|
21 | 21 |
|
22 | 22 | import tests
|
23 |
| -from sagemaker.tensorflow import TensorFlow, TensorFlowModel |
24 | 23 | from tests.integ.retry import retries
|
25 | 24 | from sagemaker.drift_check_baselines import DriftCheckBaselines
|
26 | 25 | from sagemaker import (
|
@@ -746,101 +745,3 @@ def test_model_registration_with_model_repack(
|
746 | 745 | pipeline.delete()
|
747 | 746 | except Exception:
|
748 | 747 | pass
|
749 |
| - |
750 |
| - |
751 |
| -def test_model_registration_with_tensorflow_model_with_pipeline_model( |
752 |
| - sagemaker_session, role, tf_full_version, tf_full_py_version, pipeline_name, region_name |
753 |
| -): |
754 |
| - base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") |
755 |
| - entry_point = os.path.join(base_dir, "mnist_v2.py") |
756 |
| - input_path = sagemaker_session.upload_data( |
757 |
| - path=os.path.join(base_dir, "data"), |
758 |
| - key_prefix="integ-test-data/tf-scriptmode/mnist/training", |
759 |
| - ) |
760 |
| - inputs = TrainingInput(s3_data=input_path) |
761 |
| - |
762 |
| - instance_count = ParameterInteger(name="InstanceCount", default_value=1) |
763 |
| - instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") |
764 |
| - |
765 |
| - tensorflow_estimator = TensorFlow( |
766 |
| - entry_point=entry_point, |
767 |
| - role=role, |
768 |
| - instance_count=instance_count, |
769 |
| - instance_type=instance_type, |
770 |
| - framework_version=tf_full_version, |
771 |
| - py_version=tf_full_py_version, |
772 |
| - sagemaker_session=sagemaker_session, |
773 |
| - ) |
774 |
| - step_train = TrainingStep( |
775 |
| - name="MyTrain", |
776 |
| - estimator=tensorflow_estimator, |
777 |
| - inputs=inputs, |
778 |
| - ) |
779 |
| - |
780 |
| - model = TensorFlowModel( |
781 |
| - entry_point=entry_point, |
782 |
| - framework_version="2.4", |
783 |
| - model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
784 |
| - role=role, |
785 |
| - sagemaker_session=sagemaker_session, |
786 |
| - ) |
787 |
| - |
788 |
| - pipeline_model = PipelineModel( |
789 |
| - name="MyModelPipeline", models=[model], role=role, sagemaker_session=sagemaker_session |
790 |
| - ) |
791 |
| - |
792 |
| - step_register_model = RegisterModel( |
793 |
| - name="MyRegisterModel", |
794 |
| - model=pipeline_model, |
795 |
| - model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, |
796 |
| - content_types=["application/json"], |
797 |
| - response_types=["application/json"], |
798 |
| - inference_instances=["ml.t2.medium", "ml.m5.large"], |
799 |
| - transform_instances=["ml.m5.large"], |
800 |
| - model_package_group_name=f"{pipeline_name}TestModelPackageGroup", |
801 |
| - ) |
802 |
| - |
803 |
| - pipeline = Pipeline( |
804 |
| - name=pipeline_name, |
805 |
| - parameters=[ |
806 |
| - instance_count, |
807 |
| - instance_type, |
808 |
| - ], |
809 |
| - steps=[step_train, step_register_model], |
810 |
| - sagemaker_session=sagemaker_session, |
811 |
| - ) |
812 |
| - |
813 |
| - try: |
814 |
| - response = pipeline.create(role) |
815 |
| - create_arn = response["PipelineArn"] |
816 |
| - |
817 |
| - assert re.match( |
818 |
| - rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", |
819 |
| - create_arn, |
820 |
| - ) |
821 |
| - |
822 |
| - for _ in retries( |
823 |
| - max_retry_count=5, |
824 |
| - exception_message_prefix="Waiting for a successful execution of pipeline", |
825 |
| - seconds_to_sleep=10, |
826 |
| - ): |
827 |
| - execution = pipeline.start(parameters={}) |
828 |
| - assert re.match( |
829 |
| - rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", |
830 |
| - execution.arn, |
831 |
| - ) |
832 |
| - try: |
833 |
| - execution.wait(delay=30, max_attempts=60) |
834 |
| - except WaiterError: |
835 |
| - pass |
836 |
| - execution_steps = execution.list_steps() |
837 |
| - |
838 |
| - assert len(execution_steps) == 3 |
839 |
| - for step in execution_steps: |
840 |
| - assert step["StepStatus"] == "Succeeded" |
841 |
| - break |
842 |
| - finally: |
843 |
| - try: |
844 |
| - pipeline.delete() |
845 |
| - except Exception: |
846 |
| - pass |
0 commit comments