24
24
INSTANCE_TYPE = "ml.c4.xlarge"
25
25
NUM_SAMPLES_PER_TREE = 20
26
26
NUM_TREES = 50
27
+ EVAL_METRICS = ["accuracy" , "precision_recall_fscore" ]
27
28
28
29
COMMON_TRAIN_ARGS = {
29
30
"role" : ROLE ,
@@ -70,13 +71,15 @@ def test_init_required_positional(sagemaker_session):
70
71
INSTANCE_TYPE ,
71
72
NUM_SAMPLES_PER_TREE ,
72
73
NUM_TREES ,
74
+ EVAL_METRICS ,
73
75
sagemaker_session = sagemaker_session ,
74
76
)
75
77
assert randomcutforest .role == ROLE
76
78
assert randomcutforest .instance_count == INSTANCE_COUNT
77
79
assert randomcutforest .instance_type == INSTANCE_TYPE
78
80
assert randomcutforest .num_trees == NUM_TREES
79
81
assert randomcutforest .num_samples_per_tree == NUM_SAMPLES_PER_TREE
82
+ assert randomcutforest .eval_metrics == EVAL_METRICS
80
83
81
84
82
85
def test_init_required_named (sagemaker_session ):
@@ -92,10 +95,13 @@ def test_all_hyperparameters(sagemaker_session):
92
95
sagemaker_session = sagemaker_session ,
93
96
num_trees = NUM_TREES ,
94
97
num_samples_per_tree = NUM_SAMPLES_PER_TREE ,
98
+ eval_metrics = EVAL_METRICS ,
95
99
** ALL_REQ_ARGS
96
100
)
97
101
assert randomcutforest .hyperparameters () == dict (
98
- num_samples_per_tree = str (NUM_SAMPLES_PER_TREE ), num_trees = str (NUM_TREES ),
102
+ num_samples_per_tree = str (NUM_SAMPLES_PER_TREE ),
103
+ num_trees = str (NUM_TREES ),
104
+ eval_metrics = '["accuracy", "precision_recall_fscore"]' ,
99
105
)
100
106
101
107
@@ -104,7 +110,7 @@ def test_image(sagemaker_session):
104
110
assert image_uris .retrieve ("randomcutforest" , REGION ) == randomcutforest .train_image ()
105
111
106
112
107
- @pytest .mark .parametrize ("iterable_hyper_parameters, value" , [("eval_metrics" , [ 0 ] )])
113
+ @pytest .mark .parametrize ("iterable_hyper_parameters, value" , [("eval_metrics" , 0 )])
108
114
def test_iterable_hyper_parameters_type (sagemaker_session , iterable_hyper_parameters , value ):
109
115
with pytest .raises (TypeError ):
110
116
test_params = ALL_REQ_ARGS .copy ()
0 commit comments