23
23
24
24
from tests .integ import test_local_mode
25
25
from tests .unit import SAGEMAKER_CONFIG_TRANSFORM_JOB
26
+ from sagemaker .model_monitor import DatasetFormat
27
+ from sagemaker .workflow .quality_check_step import (
28
+ ModelQualityCheckConfig ,
29
+ )
30
+ from sagemaker .workflow .check_job_config import CheckJobConfig
31
+
32
+ _CHECK_JOB_PREFIX = "CheckJobPrefix"
26
33
27
34
ROLE = "DummyRole"
28
35
REGION = "us-west-2"
49
56
"base_transform_job_name" : JOB_NAME ,
50
57
}
51
58
59
+ PROCESS_REQUEST_ARGS = {
60
+ "inputs" : "processing_inputs" ,
61
+ "output_config" : "output_config" ,
62
+ "job_name" : "job_name" ,
63
+ "resources" : "resource_config" ,
64
+ "stopping_condition" : {"MaxRuntimeInSeconds" : 3600 },
65
+ "app_specification" : "app_specification" ,
66
+ "experiment_config" : {"ExperimentName" : "AnExperiment" },
67
+ }
68
+
52
69
MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer" : {"Image" : IMAGE_URI }}
53
70
54
71
MODEL_DESC_CONTAINERS_ONLY = {"Containers" : [{"Image" : IMAGE_URI }]}
@@ -72,7 +89,7 @@ def mock_create_tar_file():
72
89
73
90
@pytest .fixture ()
74
91
def sagemaker_session ():
75
- boto_mock = Mock (name = "boto_session" )
92
+ boto_mock = Mock (name = "boto_session" , region_name = REGION )
76
93
session = Mock (
77
94
name = "sagemaker_session" ,
78
95
boto_session = boto_mock ,
@@ -764,6 +781,48 @@ def test_stop_transform_job(sagemaker_session, transformer):
764
781
sagemaker_session .stop_transform_job .assert_called_once_with (name = JOB_NAME )
765
782
766
783
784
+ @patch ("sagemaker.transformer.Transformer._retrieve_image_uri" , return_value = IMAGE_URI )
785
+ @patch ("sagemaker.workflow.pipeline.Pipeline.upsert" , return_value = {})
786
+ @patch ("sagemaker.workflow.pipeline.Pipeline.start" , return_value = Mock ())
787
+ def test_transform_with_monitoring_create_and_starts_pipeline (
788
+ pipeline_start , upsert , image_uri , sagemaker_session , transformer
789
+ ):
790
+
791
+ config = CheckJobConfig (
792
+ role = ROLE ,
793
+ instance_count = 1 ,
794
+ instance_type = "ml.m5.xlarge" ,
795
+ volume_size_in_gb = 60 ,
796
+ max_runtime_in_seconds = 1800 ,
797
+ sagemaker_session = sagemaker_session ,
798
+ base_job_name = _CHECK_JOB_PREFIX ,
799
+ )
800
+
801
+ quality_check_config = ModelQualityCheckConfig (
802
+ baseline_dataset = "s3://baseline_dataset_s3_url" ,
803
+ dataset_format = DatasetFormat .csv (header = True ),
804
+ problem_type = "BinaryClassification" ,
805
+ inference_attribute = "quality_cfg_attr_value" ,
806
+ probability_attribute = "quality_cfg_attr_value" ,
807
+ ground_truth_attribute = "quality_cfg_attr_value" ,
808
+ probability_threshold_attribute = "quality_cfg_attr_value" ,
809
+ post_analytics_processor_script = "s3://my_bucket/data_quality/postprocessor.py" ,
810
+ output_s3_uri = "s3://output_s3_uri" ,
811
+ )
812
+
813
+ transformer .transform_with_monitoring (
814
+ monitoring_config = quality_check_config ,
815
+ monitoring_resource_config = config ,
816
+ data = DATA ,
817
+ content_type = "text/libsvm" ,
818
+ supplied_baseline_constraints = "supplied_baseline_constraints" ,
819
+ role = ROLE ,
820
+ )
821
+
822
+ upsert .assert_called_once ()
823
+ pipeline_start .assert_called_once ()
824
+
825
+
767
826
def test_stop_transform_job_no_transform_job (transformer ):
768
827
with pytest .raises (ValueError ) as e :
769
828
transformer .stop_transform_job ()
0 commit comments