Skip to content

Commit d857533

Browse files
author
Chuyang Deng
committed
recover eval_metrics for kmeans and rcf
1 parent b8dcc99 commit d857533

10 files changed

+45
-5
lines changed

doc/algorithms/randomcutforest.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The Amazon SageMaker Random Cut Forest algorithm.
88
:undoc-members:
99
:show-inheritance:
1010
:inherited-members:
11-
:exclude-members: image_uri, num_trees, num_samples_per_tree, feature_dim, MINI_BATCH_SIZE
11+
:exclude-members: image_uri, num_trees, num_samples_per_tree, eval_metrics, feature_dim, MINI_BATCH_SIZE
1212

1313

1414
.. autoclass:: sagemaker.RandomCutForestModel

src/sagemaker/amazon/kmeans.py

+12
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class KMeans(AmazonAlgorithmEstimatorBase):
4343
)
4444
epochs = hp("epochs", gt(0), "An integer greater-than 0", int)
4545
center_factor = hp("extra_center_factor", gt(0), "An integer greater-than 0", int)
46+
eval_metrics = hp(
47+
name="eval_metrics",
48+
validation_message='A comma separated list of "msd" or "ssd"',
49+
data_type=list,
50+
)
4651

4752
def __init__(
4853
self,
@@ -58,6 +63,7 @@ def __init__(
5863
half_life_time_size=None,
5964
epochs=None,
6065
center_factor=None,
66+
eval_metrics=None,
6167
**kwargs
6268
):
6369
"""A k-means clustering
@@ -124,6 +130,11 @@ def __init__(
124130
center_factor (int): The algorithm will create
125131
``num_clusters * extra_center_factor`` as it runs and reduce the
126132
number of centers to ``k`` when finalizing
133+
eval_metrics (list): JSON list of metrics types to be used for
134+
reporting the score for the model. Allowed values are "msd"
135+
Means Square Error, "ssd": Sum of square distance. If test data
136+
is provided, the score shall be reported in terms of all
137+
requested metrics.
127138
**kwargs: base class keyword argument values.
128139
129140
.. tip::
@@ -142,6 +153,7 @@ def __init__(
142153
self.half_life_time_size = half_life_time_size
143154
self.epochs = epochs
144155
self.center_factor = center_factor
156+
self.eval_metrics = eval_metrics
145157

146158
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
147159
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing

src/sagemaker/amazon/randomcutforest.py

+13
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ class RandomCutForest(AmazonAlgorithmEstimatorBase):
3131
repo_version = 1
3232
MINI_BATCH_SIZE = 1000
3333

34+
eval_metrics = hp(
35+
name="eval_metrics",
36+
validation_message='A comma separated list of "accuracy" or "precision_recall_fscore"',
37+
data_type=list,
38+
)
39+
3440
num_trees = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int)
3541
num_samples_per_tree = hp(
3642
"num_samples_per_tree", (ge(1), le(2048)), "An integer in [1, 2048]", int
@@ -44,6 +50,7 @@ def __init__(
4450
instance_type,
4551
num_samples_per_tree=None,
4652
num_trees=None,
53+
eval_metrics=None,
4754
**kwargs
4855
):
4956
"""RandomCutForest is :class:`Estimator` used for anomaly detection.
@@ -92,6 +99,11 @@ def __init__(
9299
build each tree in the forest. The total number of samples drawn
93100
from the train dataset is num_trees * num_samples_per_tree.
94101
num_trees (int): Optional. The number of trees used in the forest.
102+
eval_metrics (list): Optional. JSON list of metrics types to be used
103+
for reporting the score for the model. Allowed values are
104+
"accuracy", "precision_recall_fscore": positive and negative
105+
precision, recall, and f1 scores. If test data is provided, the
106+
score shall be reported in terms of all requested metrics.
95107
**kwargs: base class keyword argument values.
96108
97109
.. tip::
@@ -104,6 +116,7 @@ def __init__(
104116
super(RandomCutForest, self).__init__(role, instance_count, instance_type, **kwargs)
105117
self.num_samples_per_tree = num_samples_per_tree
106118
self.num_trees = num_trees
119+
self.eval_metrics = eval_metrics
107120

108121
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
109122
"""Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing

tests/integ/test_airflow_config.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_
111111
kmeans.half_life_time_size = 1
112112
kmeans.epochs = 1
113113
kmeans.center_factor = 1
114+
kmeans.eval_metrics = ["ssd", "msd"]
114115

115116
records = kmeans.record_set(datasets.one_p_mnist()[0][:100])
116117

@@ -385,6 +386,7 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
385386
instance_type=cpu_instance_type,
386387
num_trees=50,
387388
num_samples_per_tree=20,
389+
eval_metrics=["accuracy", "precision_recall_fscore"],
388390
sagemaker_session=sagemaker_session,
389391
)
390392

tests/integ/test_kmeans.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
1516
import time
1617

1718
import pytest
@@ -46,6 +47,7 @@ def test_kmeans(sagemaker_session, cpu_instance_type, training_set):
4647
kmeans.half_life_time_size = 1
4748
kmeans.epochs = 1
4849
kmeans.center_factor = 1
50+
kmeans.eval_metrics = ["ssd", "msd"]
4951

5052
assert kmeans.hyperparameters() == dict(
5153
init_method=kmeans.init_method,
@@ -57,6 +59,7 @@ def test_kmeans(sagemaker_session, cpu_instance_type, training_set):
5759
epochs=str(kmeans.epochs),
5860
extra_center_factor=str(kmeans.center_factor),
5961
k=str(kmeans.k),
62+
eval_metrics=json.dumps(kmeans.eval_metrics),
6063
force_dense="True",
6164
)
6265

tests/integ/test_multidatamodel.py

+1
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def __rcf_training_job(
439439
instance_type=cpu_instance_type,
440440
num_trees=num_trees,
441441
num_samples_per_tree=num_samples_per_tree,
442+
eval_metrics=["accuracy", "precision_recall_fscore"],
442443
sagemaker_session=sagemaker_session,
443444
)
444445

tests/integ/test_randomcutforest.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_randomcutforest(sagemaker_session, cpu_instance_type):
3434
instance_type=cpu_instance_type,
3535
num_trees=50,
3636
num_samples_per_tree=20,
37+
eval_metrics=["accuracy", "precision_recall_fscore"],
3738
sagemaker_session=sagemaker_session,
3839
)
3940

tests/unit/test_kmeans.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def test_all_hyperparameters(sagemaker_session):
9292
half_life_time_size=0,
9393
epochs=10,
9494
center_factor=2,
95+
eval_metrics=["msd", "ssd"],
9596
**ALL_REQ_ARGS
9697
)
9798
assert kmeans.hyperparameters() == dict(
@@ -104,6 +105,7 @@ def test_all_hyperparameters(sagemaker_session):
104105
half_life_time_size="0",
105106
epochs="10",
106107
extra_center_factor="2",
108+
eval_metrics='["msd", "ssd"]',
107109
force_dense="True",
108110
)
109111

@@ -129,7 +131,7 @@ def test_required_hyper_parameters_value(sagemaker_session, required_hyper_param
129131
KMeans(sagemaker_session=sagemaker_session, **test_params)
130132

131133

132-
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", [0])])
134+
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])
133135
def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
134136
with pytest.raises(TypeError):
135137
test_params = ALL_REQ_ARGS.copy()

tests/unit/test_linear_learner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def test_num_classes_can_be_string_for_multiclass_classifier(sagemaker_session):
216216
LinearLearner(sagemaker_session=sagemaker_session, **test_params)
217217

218218

219-
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", [0])])
219+
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("max_iterations", [0])])
220220
def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
221221
with pytest.raises(TypeError):
222222
test_params = ALL_REQ_ARGS.copy()

tests/unit/test_randomcutforest.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
INSTANCE_TYPE = "ml.c4.xlarge"
2525
NUM_SAMPLES_PER_TREE = 20
2626
NUM_TREES = 50
27+
EVAL_METRICS = ["accuracy", "precision_recall_fscore"]
2728

2829
COMMON_TRAIN_ARGS = {
2930
"role": ROLE,
@@ -70,13 +71,15 @@ def test_init_required_positional(sagemaker_session):
7071
INSTANCE_TYPE,
7172
NUM_SAMPLES_PER_TREE,
7273
NUM_TREES,
74+
EVAL_METRICS,
7375
sagemaker_session=sagemaker_session,
7476
)
7577
assert randomcutforest.role == ROLE
7678
assert randomcutforest.instance_count == INSTANCE_COUNT
7779
assert randomcutforest.instance_type == INSTANCE_TYPE
7880
assert randomcutforest.num_trees == NUM_TREES
7981
assert randomcutforest.num_samples_per_tree == NUM_SAMPLES_PER_TREE
82+
assert randomcutforest.eval_metrics == EVAL_METRICS
8083

8184

8285
def test_init_required_named(sagemaker_session):
@@ -92,10 +95,13 @@ def test_all_hyperparameters(sagemaker_session):
9295
sagemaker_session=sagemaker_session,
9396
num_trees=NUM_TREES,
9497
num_samples_per_tree=NUM_SAMPLES_PER_TREE,
98+
eval_metrics=EVAL_METRICS,
9599
**ALL_REQ_ARGS
96100
)
97101
assert randomcutforest.hyperparameters() == dict(
98-
num_samples_per_tree=str(NUM_SAMPLES_PER_TREE), num_trees=str(NUM_TREES),
102+
num_samples_per_tree=str(NUM_SAMPLES_PER_TREE),
103+
num_trees=str(NUM_TREES),
104+
eval_metrics='["accuracy", "precision_recall_fscore"]',
99105
)
100106

101107

@@ -104,7 +110,7 @@ def test_image(sagemaker_session):
104110
assert image_uris.retrieve("randomcutforest", REGION) == randomcutforest.train_image()
105111

106112

107-
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", [0])])
113+
@pytest.mark.parametrize("iterable_hyper_parameters, value", [("eval_metrics", 0)])
108114
def test_iterable_hyper_parameters_type(sagemaker_session, iterable_hyper_parameters, value):
109115
with pytest.raises(TypeError):
110116
test_params = ALL_REQ_ARGS.copy()

0 commit comments

Comments
 (0)