Skip to content

Commit 6834806

Browse files
author
Ignacio Quintero
committed
Enforce a HP Type when setting its value.
instead of validating with isinstance() cast the hp value to the type it is meant to be. This enforces a "strongly typed" value. When we deserialize from the API string responses it becomes easier to deal with too.
1 parent bd409b6 commit 6834806

File tree

6 files changed

+49
-48
lines changed

6 files changed

+49
-48
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
2828
"""Base class for Amazon first-party Estimator implementations. This class isn't intended
2929
to be instantiated directly."""
3030

31-
feature_dim = hp('feature_dim', (validation.isint, validation.gt(0)))
32-
mini_batch_size = hp('mini_batch_size', (validation.isint, validation.gt(0)))
31+
feature_dim = hp('feature_dim', validation.gt(0), data_type=int)
32+
mini_batch_size = hp('mini_batch_size', validation.gt(0), data_type=int)
3333

3434
def __init__(self, role, train_instance_count, train_instance_type, data_location=None, **kwargs):
3535
"""Initialize an AmazonAlgorithmEstimatorBase.

src/sagemaker/amazon/factorization_machines.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1414
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
1515
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
16-
from sagemaker.amazon.validation import gt, isin, isint, ge, isnumber
16+
from sagemaker.amazon.validation import gt, isin, ge
1717
from sagemaker.predictor import RealTimePredictor
1818
from sagemaker.model import Model
1919
from sagemaker.session import Session
@@ -23,34 +23,34 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase):
2323

2424
repo = 'factorization-machines:1'
2525

26-
num_factors = hp('num_factors', (gt(0), isint), 'An integer greater than zero', int)
26+
num_factors = hp('num_factors', gt(0), 'An integer greater than zero', int)
2727
predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'),
2828
'Value "binary_classifier" or "regressor"', str)
29-
epochs = hp('epochs', (gt(0), isint), "An integer greater than 0", int)
30-
clip_gradient = hp('clip_gradient', isnumber, "A float value", float)
31-
eps = hp('eps', isnumber, "A float value", float)
32-
rescale_grad = hp('rescale_grad', isnumber, "A float value", float)
33-
bias_lr = hp('bias_lr', (ge(0), isnumber), "A non-negative float", float)
34-
linear_lr = hp('linear_lr', (ge(0), isnumber), "A non-negative float", float)
35-
factors_lr = hp('factors_lr', (ge(0), isnumber), "A non-negative float", float)
36-
bias_wd = hp('bias_wd', (ge(0), isnumber), "A non-negative float", float)
37-
linear_wd = hp('linear_wd', (ge(0), isnumber), "A non-negative float", float)
38-
factors_wd = hp('factors_wd', (ge(0), isnumber), "A non-negative float", float)
29+
epochs = hp('epochs', gt(0), "An integer greater than 0", int)
30+
clip_gradient = hp('clip_gradient', (), "A float value", float)
31+
eps = hp('eps', (), "A float value", float)
32+
rescale_grad = hp('rescale_grad', (), "A float value", float)
33+
bias_lr = hp('bias_lr', ge(0), "A non-negative float", float)
34+
linear_lr = hp('linear_lr', ge(0), "A non-negative float", float)
35+
factors_lr = hp('factors_lr', ge(0), "A non-negative float", float)
36+
bias_wd = hp('bias_wd', ge(0), "A non-negative float", float)
37+
linear_wd = hp('linear_wd', ge(0), "A non-negative float", float)
38+
factors_wd = hp('factors_wd', ge(0), "A non-negative float", float)
3939
bias_init_method = hp('bias_init_method', isin('normal', 'uniform', 'constant'),
4040
'Value "normal", "uniform" or "constant"', str)
41-
bias_init_scale = hp('bias_init_scale', (ge(0), isnumber), "A non-negative float", float)
42-
bias_init_sigma = hp('bias_init_sigma', (ge(0), isnumber), "A non-negative float", float)
43-
bias_init_value = hp('bias_init_value', isnumber, "A float value", float)
41+
bias_init_scale = hp('bias_init_scale', ge(0), "A non-negative float", float)
42+
bias_init_sigma = hp('bias_init_sigma', ge(0), "A non-negative float", float)
43+
bias_init_value = hp('bias_init_value', (), "A float value", float)
4444
linear_init_method = hp('linear_init_method', isin('normal', 'uniform', 'constant'),
4545
'Value "normal", "uniform" or "constant"', str)
46-
linear_init_scale = hp('linear_init_scale', (ge(0), isnumber), "A non-negative float", float)
47-
linear_init_sigma = hp('linear_init_sigma', (ge(0), isnumber), "A non-negative float", float)
48-
linear_init_value = hp('linear_init_value', isnumber, "A float value", float)
46+
linear_init_scale = hp('linear_init_scale', ge(0), "A non-negative float", float)
47+
linear_init_sigma = hp('linear_init_sigma', ge(0), "A non-negative float", float)
48+
linear_init_value = hp('linear_init_value', (), "A float value", float)
4949
factors_init_method = hp('factors_init_method', isin('normal', 'uniform', 'constant'),
5050
'Value "normal", "uniform" or "constant"', str)
51-
factors_init_scale = hp('factors_init_scale', (ge(0), isnumber), "A non-negative float", float)
52-
factors_init_sigma = hp('factors_init_sigma', (ge(0), isnumber), "A non-negative float", float)
53-
factors_init_value = hp('factors_init_value', isnumber, "A float value", float)
51+
factors_init_scale = hp('factors_init_scale', ge(0), "A non-negative float", float)
52+
factors_init_sigma = hp('factors_init_sigma', ge(0), "A non-negative float", float)
53+
factors_init_value = hp('factors_init_value', (), "A float value", float)
5454

5555
def __init__(self, role, train_instance_count, train_instance_type,
5656
num_factors, predictor_type,

src/sagemaker/amazon/hyperparameter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty
3636
def validate(self, value):
3737
if value is None: # We allow assignment from None, but Nones are not sent to training.
3838
return
39+
3940
for valid in self.validation:
4041
if not valid(value):
41-
error_message = "Invalid hyperparameter value {}".format(value)
42+
error_message = "Invalid hyperparameter value {} for {}".format(value, self.name)
4243
if self.validation_message:
4344
error_message = error_message + ". Expecting: " + self.validation_message
4445
raise ValueError(error_message)
@@ -51,6 +52,7 @@ def __get__(self, obj, objtype):
5152

5253
def __set__(self, obj, value):
5354
"""Validate the supplied value and set this hyperparameter to value"""
55+
value = self.data_type(value)
5456
self.validate(value)
5557
if '_hyperparameters' not in dir(obj):
5658
obj._hyperparameters = dict()

src/sagemaker/amazon/kmeans.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1414
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
1515
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
16-
from sagemaker.amazon.validation import gt, isin, isint, ge
16+
from sagemaker.amazon.validation import gt, isin, ge
1717
from sagemaker.predictor import RealTimePredictor
1818
from sagemaker.model import Model
1919
from sagemaker.session import Session
@@ -23,15 +23,15 @@ class KMeans(AmazonAlgorithmEstimatorBase):
2323

2424
repo = 'kmeans:1'
2525

26-
k = hp('k', (gt(1), isint), 'An integer greater-than 1', int)
26+
k = hp('k', gt(1), 'An integer greater-than 1', int)
2727
init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
28-
max_iterations = hp('local_lloyd_max_iterations', (gt(0), isint), 'An integer greater-than 0', int)
29-
tol = hp('local_lloyd_tol', (gt(0), isint), 'An integer greater-than 0', int)
30-
num_trials = hp('local_lloyd_num_trials', (gt(0), isint), 'An integer greater-than 0', int)
28+
max_iterations = hp('local_lloyd_max_iterations', gt(0), 'An integer greater-than 0', int)
29+
tol = hp('local_lloyd_tol', gt(0), 'An integer greater-than 0', int)
30+
num_trials = hp('local_lloyd_num_trials', gt(0), 'An integer greater-than 0', int)
3131
local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
32-
half_life_time_size = hp('half_life_time_size', (ge(0), isint), 'An integer greater-than-or-equal-to 0', int)
33-
epochs = hp('epochs', (gt(0), isint), 'An integer greater-than 0', int)
34-
center_factor = hp('extra_center_factor', (gt(0), isint), 'An integer greater-than 0', int)
32+
half_life_time_size = hp('half_life_time_size', ge(0), 'An integer greater-than-or-equal-to 0', int)
33+
epochs = hp('epochs', gt(0), 'An integer greater-than 0', int)
34+
center_factor = hp('extra_center_factor', gt(0), 'An integer greater-than 0', int)
3535

3636
def __init__(self, role, train_instance_count, train_instance_type, k, init_method=None,
3737
max_iterations=None, tol=None, num_trials=None, local_init_method=None,

src/sagemaker/amazon/linear_learner.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
1414
from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
1515
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
16-
from sagemaker.amazon.validation import isin, gt, lt, isint, isbool, isnumber
16+
from sagemaker.amazon.validation import isin, gt, lt, isbool
1717
from sagemaker.predictor import RealTimePredictor
1818
from sagemaker.model import Model
1919
from sagemaker.session import Session
@@ -32,16 +32,16 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
3232
target_recall = hp('target_recall', (gt(0), lt(1)), "A float in (0,1)", float)
3333
target_precision = hp('target_precision', (gt(0), lt(1)), "A float in (0,1)", float)
3434
positive_example_weight_mult = hp('positive_example_weight_mult', gt(0), "A float greater than 0", float)
35-
epochs = hp('epochs', (gt(0), isint), "An integer greater-than 0", int)
35+
epochs = hp('epochs', gt(0), "An integer greater-than 0", int)
3636
predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'),
3737
'One of "binary_classifier" or "regressor"', str)
3838
use_bias = hp('use_bias', isbool, "Either True or False", bool)
39-
num_models = hp('num_models', (gt(0), isint), "An integer greater-than 0", int)
40-
num_calibration_samples = hp('num_calibration_samples', (gt(0), isint), "An integer greater-than 0", int)
39+
num_models = hp('num_models', gt(0), "An integer greater-than 0", int)
40+
num_calibration_samples = hp('num_calibration_samples', gt(0), "An integer greater-than 0", int)
4141
init_method = hp('init_method', isin('uniform', 'normal'), 'One of "uniform" or "normal"', str)
4242
init_scale = hp('init_scale', (gt(-1), lt(1)), 'A float in (-1, 1)', float)
4343
init_sigma = hp('init_sigma', (gt(0), lt(1)), 'A float in (0, 1)', float)
44-
init_bias = hp('init_bias', isnumber, 'A number', float)
44+
init_bias = hp('init_bias', (), 'A number', float)
4545
optimizer = hp('optimizer', isin('sgd', 'adam', 'auto'), 'One of "sgd", "adam" or "auto', str)
4646
loss = hp('loss', isin('logistic', 'squared_loss', 'absolute_loss', 'auto'),
4747
'"logistic", "squared_loss", "absolute_loss" or"auto"', str)
@@ -53,15 +53,15 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
5353
beta_2 = hp('beta_1', (gt(0), lt(1)), 'A float in (0,1)', float)
5454
bias_lr_mult = hp('bias_lr_mult', gt(0), 'A float greater-than 0', float)
5555
bias_wd_mult = hp('bias_wd_mult', gt(0), 'A float greater-than 0', float)
56-
use_lr_scheduler = hp('use_lr_scheduler', isbool, 'A boolean', bool)
57-
lr_scheduler_step = hp('lr_scheduler_step', (gt(0), isint), 'An integer greater-than 0', int)
56+
use_lr_scheduler = hp('use_lr_scheduler', (), 'A boolean', bool)
57+
lr_scheduler_step = hp('lr_scheduler_step', gt(0), 'An integer greater-than 0', int)
5858
lr_scheduler_factor = hp('lr_scheduler_factor', (gt(0), lt(1)), 'A float in (0,1)', float)
5959
lr_scheduler_minimum_lr = hp('lr_scheduler_minimum_lr', gt(0), 'A float greater-than 0', float)
60-
normalize_data = hp('normalize_data', isbool, 'A boolean', bool)
61-
normalize_label = hp('normalize_label', isbool, 'A boolean', bool)
62-
unbias_data = hp('unbias_data', isbool, 'A boolean', bool)
63-
unbias_label = hp('unbias_label', isbool, 'A boolean', bool)
64-
num_point_for_scalar = hp('num_point_for_scalar', (isint, gt(0)), 'An integer greater-than 0', int)
60+
normalize_data = hp('normalize_data', (), 'A boolean', bool)
61+
normalize_label = hp('normalize_label', (), 'A boolean', bool)
62+
unbias_data = hp('unbias_data', (), 'A boolean', bool)
63+
unbias_label = hp('unbias_label', (), 'A boolean', bool)
64+
num_point_for_scalar = hp('num_point_for_scalar', gt(0), 'An integer greater-than 0', int)
6565

6666
def __init__(self, role, train_instance_count, train_instance_type, predictor_type='binary_classifier',
6767
binary_classifier_model_selection_criteria=None, target_recall=None, target_precision=None,

src/sagemaker/amazon/pca.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@ class PCA(AmazonAlgorithmEstimatorBase):
2424

2525
DEFAULT_MINI_BATCH_SIZE = 500
2626

27-
num_components = hp(name='num_components', validate=lambda x: x > 0 and isinstance(x, int),
27+
num_components = hp(name='num_components', validate=lambda x: x > 0,
2828
validation_message='Value must be an integer greater than zero', data_type=int)
2929
algorithm_mode = hp(name='algorithm_mode', validate=lambda x: x in ['regular', 'stable', 'randomized'],
3030
validation_message='Value must be one of "regular", "stable", "randomized"', data_type=str)
31-
subtract_mean = hp(name='subtract_mean', validate=lambda x: isinstance(x, bool),
32-
validation_message='Value must be a boolean', data_type=bool)
33-
extra_components = hp(name='extra_components', validate=lambda x: x >= 0 and isinstance(x, int),
31+
subtract_mean = hp(name='subtract_mean', validation_message='Value must be a boolean', data_type=bool)
32+
extra_components = hp(name='extra_components', validate=lambda x: x >= 0,
3433
validation_message="Value must be an integer greater than or equal to 0", data_type=int)
3534

3635
def __init__(self, role, train_instance_count, train_instance_type, num_components,

0 commit comments

Comments
 (0)