16
16
import pytest
17
17
18
18
from sagemaker .model_monitor import DatasetFormat
19
- from sagemaker .workflow .parameters import ParameterString
19
+ from sagemaker .workflow .execution_variables import ExecutionVariable
20
+ from sagemaker .workflow .parameters import ParameterString , ParameterInteger
20
21
from sagemaker .workflow .pipeline import Pipeline
21
22
from sagemaker .workflow .pipeline import PipelineDefinitionConfig
22
23
from sagemaker .workflow .quality_check_step import (
178
179
"dataset_source" : "/opt/ml/processing/input/baseline_dataset_input" ,
179
180
"analysis_type" : "MODEL_QUALITY" ,
180
181
"problem_type" : "BinaryClassification" ,
181
- "probability_attribute" : "0" ,
182
- "probability_threshold_attribute" : "0.5" ,
183
182
},
184
183
"StoppingCondition" : {"MaxRuntimeInSeconds" : 1800 },
185
184
},
@@ -269,23 +268,54 @@ def test_data_quality_check_step(
269
268
assert step_definition == _expected_data_quality_dsl
270
269
271
270
271
+ @pytest .mark .parametrize (
272
+ "quality_cfg_attr_value, expected_value_in_dsl" ,
273
+ [
274
+ (0 , "0" ),
275
+ ("attr" , "attr" ),
276
+ (None , None ),
277
+ (ParameterString (name = "ParamStringEnvVar" ), {"Get" : "Parameters.ParamStringEnvVar" }),
278
+ (ExecutionVariable ("PipelineArn" ), {"Get" : "Execution.PipelineArn" }),
279
+ (ParameterInteger (name = "ParamIntEnvVar" ), "Error" ),
280
+ ],
281
+ )
272
282
def test_model_quality_check_step (
273
283
sagemaker_session ,
274
284
check_job_config ,
275
285
model_package_group_name ,
276
286
supplied_baseline_statistics_uri ,
277
287
supplied_baseline_constraints_uri ,
288
+ quality_cfg_attr_value ,
289
+ expected_value_in_dsl ,
278
290
):
279
291
model_quality_check_config = ModelQualityCheckConfig (
280
292
baseline_dataset = "baseline_dataset_s3_url" ,
281
293
dataset_format = DatasetFormat .csv (header = True ),
282
294
problem_type = "BinaryClassification" ,
283
- probability_attribute = 0 , # the integer should be converted to str by SDK
284
- ground_truth_attribute = None ,
285
- probability_threshold_attribute = 0.5 , # the float should be converted to str by SDK
295
+ inference_attribute = quality_cfg_attr_value ,
296
+ probability_attribute = quality_cfg_attr_value ,
297
+ ground_truth_attribute = quality_cfg_attr_value ,
298
+ probability_threshold_attribute = quality_cfg_attr_value ,
286
299
post_analytics_processor_script = "s3://my_bucket/data_quality/postprocessor.py" ,
287
300
output_s3_uri = "" ,
288
301
)
302
+
303
+ if expected_value_in_dsl == "Error" :
304
+ with pytest .raises (ValueError ) as err :
305
+ QualityCheckStep (
306
+ name = "ModelQualityCheckStep" ,
307
+ register_new_baseline = False ,
308
+ skip_check = False ,
309
+ fail_on_violation = True ,
310
+ quality_check_config = model_quality_check_config ,
311
+ check_job_config = check_job_config ,
312
+ model_package_group_name = model_package_group_name ,
313
+ supplied_baseline_statistics = supplied_baseline_statistics_uri ,
314
+ supplied_baseline_constraints = supplied_baseline_constraints_uri ,
315
+ )
316
+ assert "cannot be Parameter types other than ParameterString" in str (err )
317
+ return
318
+
289
319
model_quality_check_step = QualityCheckStep (
290
320
name = "ModelQualityCheckStep" ,
291
321
register_new_baseline = False ,
@@ -297,6 +327,7 @@ def test_model_quality_check_step(
297
327
supplied_baseline_statistics = supplied_baseline_statistics_uri ,
298
328
supplied_baseline_constraints = supplied_baseline_constraints_uri ,
299
329
)
330
+
300
331
pipeline = Pipeline (
301
332
name = "MyPipeline" ,
302
333
parameters = [
@@ -310,6 +341,16 @@ def test_model_quality_check_step(
310
341
311
342
step_definition = _get_step_definition_for_test (pipeline )
312
343
344
+ step_def_env = step_definition ["Arguments" ]["Environment" ]
345
+ for var in [
346
+ "inference_attribute" ,
347
+ "probability_attribute" ,
348
+ "ground_truth_attribute" ,
349
+ "probability_threshold_attribute" ,
350
+ ]:
351
+ env_var_dsl = step_def_env .pop (var , None )
352
+ assert env_var_dsl == expected_value_in_dsl
353
+
313
354
assert step_definition == _expected_model_quality_dsl
314
355
315
356
0 commit comments