Skip to content

Commit afb3bbe

Browse files
Fix kmeans max_iterations hyperparameter (aws#199)
1 parent f132882 commit afb3bbe

File tree

4 files changed

+38
-7
lines changed

4 files changed

+38
-7
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
CHANGELOG
33
=========
44

5+
1.3.dev1
6+
========
7+
8+
* bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans from ``local_lloyd_max_iterations`` to ``local_lloyd_max_iter``
9+
510
1.3.0
6-
=======
11+
=====
712

813
* feature: Add chainer
914

1015
1.2.5
11-
========
16+
=====
1217

1318
* bug-fix: Change module names to string type in __all__
1419
* feature: Save training output files in local mode

src/sagemaker/amazon/kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class KMeans(AmazonAlgorithmEstimatorBase):
2828

2929
k = hp('k', gt(1), 'An integer greater-than 1', int)
3030
init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
31-
max_iterations = hp('local_lloyd_max_iterations', gt(0), 'An integer greater-than 0', int)
31+
max_iterations = hp('local_lloyd_max_iter', gt(0), 'An integer greater-than 0', int)
3232
tol = hp('local_lloyd_tol', (ge(0), le(1)), 'An float in [0, 1]', float)
3333
num_trials = hp('local_lloyd_num_trials', gt(0), 'An integer greater-than 0', int)
3434
local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)

tests/integ/test_kmeans.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,27 @@ def test_kmeans(sagemaker_session):
4141
k=10, sagemaker_session=sagemaker_session, base_job_name='test-kmeans')
4242

4343
kmeans.init_method = 'random'
44-
kmeans.max_iterators = 1
44+
kmeans.max_iterations = 1
4545
kmeans.tol = 1
4646
kmeans.num_trials = 1
4747
kmeans.local_init_method = 'kmeans++'
4848
kmeans.half_life_time_size = 1
4949
kmeans.epochs = 1
5050
kmeans.center_factor = 1
5151

52+
assert kmeans.hyperparameters() == dict(
53+
init_method=kmeans.init_method,
54+
local_lloyd_max_iter=str(kmeans.max_iterations),
55+
local_lloyd_tol=str(kmeans.tol),
56+
local_lloyd_num_trials=str(kmeans.num_trials),
57+
local_lloyd_init_method=kmeans.local_init_method,
58+
half_life_time_size=str(kmeans.half_life_time_size),
59+
epochs=str(kmeans.epochs),
60+
extra_center_factor=str(kmeans.center_factor),
61+
k=str(kmeans.k),
62+
force_dense='True',
63+
)
64+
5265
kmeans.fit(kmeans.record_set(train_set[0][:100]))
5366

5467
endpoint_name = name_from_base('kmeans')
@@ -80,14 +93,27 @@ def test_async_kmeans(sagemaker_session):
8093
k=10, sagemaker_session=sagemaker_session, base_job_name='test-kmeans')
8194

8295
kmeans.init_method = 'random'
83-
kmeans.max_iterators = 1
96+
kmeans.max_iterations = 1
8497
kmeans.tol = 1
8598
kmeans.num_trials = 1
8699
kmeans.local_init_method = 'kmeans++'
87100
kmeans.half_life_time_size = 1
88101
kmeans.epochs = 1
89102
kmeans.center_factor = 1
90103

104+
assert kmeans.hyperparameters() == dict(
105+
init_method=kmeans.init_method,
106+
local_lloyd_max_iter=str(kmeans.max_iterations),
107+
local_lloyd_tol=str(kmeans.tol),
108+
local_lloyd_num_trials=str(kmeans.num_trials),
109+
local_lloyd_init_method=kmeans.local_init_method,
110+
half_life_time_size=str(kmeans.half_life_time_size),
111+
epochs=str(kmeans.epochs),
112+
extra_center_factor=str(kmeans.center_factor),
113+
k=str(kmeans.k),
114+
force_dense='True',
115+
)
116+
91117
kmeans.fit(kmeans.record_set(train_set[0][:100]), wait=False)
92118
training_job_name = kmeans.latest_training_job.name
93119

tests/unit/test_kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ def test_all_hyperparameters(sagemaker_session):
7474
assert kmeans.hyperparameters() == dict(
7575
k=str(ALL_REQ_ARGS['k']),
7676
init_method='random',
77-
local_lloyd_max_iterations='3',
77+
local_lloyd_max_iter='3',
7878
local_lloyd_tol='0.5',
7979
local_lloyd_num_trials='5',
8080
local_lloyd_init_method='kmeans++',
8181
half_life_time_size='0',
8282
epochs='10',
8383
extra_center_factor='2',
8484
eval_metrics='[\'msd\', \'ssd\']',
85-
force_dense='True'
85+
force_dense='True',
8686
)
8787

8888

0 commit comments

Comments
 (0)