Skip to content

Commit 100d906

Browse files
author
Dan
authored
change: fix list serialization for 1P algos (#922)
1 parent 0cf5902 commit 100d906

File tree

5 files changed

+13
-3
lines changed

5 files changed

+13
-3
lines changed

src/sagemaker/amazon/hyperparameter.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import json
16+
1517

1618
class Hyperparameter(object):
1719
"""An algorithm hyperparameter with optional validation. Implemented as a python
@@ -67,4 +69,8 @@ def serialize_all(obj):
6769
"""Return all non-None ``hyperparameter`` values on ``obj`` as a ``dict[str,str].``"""
6870
if "_hyperparameters" not in dir(obj):
6971
return {}
70-
return {k: str(v) for k, v in obj._hyperparameters.items() if v is not None}
72+
return {
73+
k: json.dumps(v) if isinstance(v, list) else str(v)
74+
for k, v in obj._hyperparameters.items()
75+
if v is not None
76+
}

tests/integ/test_kmeans.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import gzip
16+
import json
1617
import os
1718
import pickle
1819
import sys
@@ -52,6 +53,7 @@ def test_kmeans(sagemaker_session):
5253
kmeans.half_life_time_size = 1
5354
kmeans.epochs = 1
5455
kmeans.center_factor = 1
56+
kmeans.eval_metrics = ["ssd", "msd"]
5557

5658
assert kmeans.hyperparameters() == dict(
5759
init_method=kmeans.init_method,
@@ -63,6 +65,7 @@ def test_kmeans(sagemaker_session):
6365
epochs=str(kmeans.epochs),
6466
extra_center_factor=str(kmeans.center_factor),
6567
k=str(kmeans.k),
68+
eval_metrics=json.dumps(kmeans.eval_metrics),
6669
force_dense="True",
6770
)
6871

tests/integ/test_randomcutforest.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_randomcutforest(sagemaker_session):
3434
train_instance_type="ml.c4.xlarge",
3535
num_trees=50,
3636
num_samples_per_tree=20,
37+
eval_metrics=["accuracy", "precision_recall_fscore"],
3738
sagemaker_session=sagemaker_session,
3839
)
3940

tests/unit/test_kmeans.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_all_hyperparameters(sagemaker_session):
104104
half_life_time_size="0",
105105
epochs="10",
106106
extra_center_factor="2",
107-
eval_metrics="['msd', 'ssd']",
107+
eval_metrics='["msd", "ssd"]',
108108
force_dense="True",
109109
)
110110

tests/unit/test_randomcutforest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_all_hyperparameters(sagemaker_session):
100100
assert randomcutforest.hyperparameters() == dict(
101101
num_samples_per_tree=str(NUM_SAMPLES_PER_TREE),
102102
num_trees=str(NUM_TREES),
103-
eval_metrics="{}".format(EVAL_METRICS),
103+
eval_metrics='["accuracy", "precision_recall_fscore"]',
104104
)
105105

106106

0 commit comments

Comments
 (0)