Commit 54b3830

Add data_type to hyperparameters (#54)
When we describe a training job, the data type of the hyperparameters is lost because we use a dict[str, str]. This adds a new data_type field to Hyperparameter so that we can convert the data types at runtime. Instead of validating with isinstance(), we cast the hyperparameter value to the type it is meant to be. This enforces a "strongly typed" value, and it also makes the string responses easier to deal with when we deserialize from the API.
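To make the motivation concrete, here is the difference in plain Python (not SDK code): type-checking rejects the strings the API returns, while casting recovers the value in its intended type.

```python
raw = "128"             # hyperparameter value as it comes back from the API, e.g. DescribeTrainingJob

isinstance(raw, int)    # False -- type-checking rejects the string outright
int(raw)                # 128   -- casting recovers the value in its intended type
```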
1 parent e82fb4f commit 54b3830

8 files changed (+101 -87)

src/sagemaker/amazon/amazon_estimator.py (+2 -2)

@@ -28,8 +28,8 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
     """Base class for Amazon first-party Estimator implementations. This class isn't intended
     to be instantiated directly."""
 
-    feature_dim = hp('feature_dim', (validation.isint, validation.gt(0)))
-    mini_batch_size = hp('mini_batch_size', (validation.isint, validation.gt(0)))
+    feature_dim = hp('feature_dim', validation.gt(0), data_type=int)
+    mini_batch_size = hp('mini_batch_size', validation.gt(0), data_type=int)
 
     def __init__(self, role, train_instance_count, train_instance_type, data_location=None, **kwargs):
         """Initialize an AmazonAlgorithmEstimatorBase.

src/sagemaker/amazon/factorization_machines.py (+25 -25)

@@ -13,7 +13,7 @@
 from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
 from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
-from sagemaker.amazon.validation import gt, isin, isint, ge, isnumber
+from sagemaker.amazon.validation import gt, isin, ge
 from sagemaker.predictor import RealTimePredictor
 from sagemaker.model import Model
 from sagemaker.session import Session
@@ -23,34 +23,34 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase):
 
     repo = 'factorization-machines:1'
 
-    num_factors = hp('num_factors', (gt(0), isint), 'An integer greater than zero')
+    num_factors = hp('num_factors', gt(0), 'An integer greater than zero', int)
     predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'),
-                        'Value "binary_classifier" or "regressor"')
-    epochs = hp('epochs', (gt(0), isint), "An integer greater than 0")
-    clip_gradient = hp('clip_gradient', isnumber, "A float value")
-    eps = hp('eps', isnumber, "A float value")
-    rescale_grad = hp('rescale_grad', isnumber, "A float value")
-    bias_lr = hp('bias_lr', (ge(0), isnumber), "A non-negative float")
-    linear_lr = hp('linear_lr', (ge(0), isnumber), "A non-negative float")
-    factors_lr = hp('factors_lr', (ge(0), isnumber), "A non-negative float")
-    bias_wd = hp('bias_wd', (ge(0), isnumber), "A non-negative float")
-    linear_wd = hp('linear_wd', (ge(0), isnumber), "A non-negative float")
-    factors_wd = hp('factors_wd', (ge(0), isnumber), "A non-negative float")
+                        'Value "binary_classifier" or "regressor"', str)
+    epochs = hp('epochs', gt(0), "An integer greater than 0", int)
+    clip_gradient = hp('clip_gradient', (), "A float value", float)
+    eps = hp('eps', (), "A float value", float)
+    rescale_grad = hp('rescale_grad', (), "A float value", float)
+    bias_lr = hp('bias_lr', ge(0), "A non-negative float", float)
+    linear_lr = hp('linear_lr', ge(0), "A non-negative float", float)
+    factors_lr = hp('factors_lr', ge(0), "A non-negative float", float)
+    bias_wd = hp('bias_wd', ge(0), "A non-negative float", float)
+    linear_wd = hp('linear_wd', ge(0), "A non-negative float", float)
+    factors_wd = hp('factors_wd', ge(0), "A non-negative float", float)
     bias_init_method = hp('bias_init_method', isin('normal', 'uniform', 'constant'),
-                          'Value "normal", "uniform" or "constant"')
-    bias_init_scale = hp('bias_init_scale', (ge(0), isnumber), "A non-negative float")
-    bias_init_sigma = hp('bias_init_sigma', (ge(0), isnumber), "A non-negative float")
-    bias_init_value = hp('bias_init_value', isnumber, "A float value")
+                          'Value "normal", "uniform" or "constant"', str)
+    bias_init_scale = hp('bias_init_scale', ge(0), "A non-negative float", float)
+    bias_init_sigma = hp('bias_init_sigma', ge(0), "A non-negative float", float)
+    bias_init_value = hp('bias_init_value', (), "A float value", float)
     linear_init_method = hp('linear_init_method', isin('normal', 'uniform', 'constant'),
-                            'Value "normal", "uniform" or "constant"')
-    linear_init_scale = hp('linear_init_scale', (ge(0), isnumber), "A non-negative float")
-    linear_init_sigma = hp('linear_init_sigma', (ge(0), isnumber), "A non-negative float")
-    linear_init_value = hp('linear_init_value', isnumber, "A float value")
+                            'Value "normal", "uniform" or "constant"', str)
+    linear_init_scale = hp('linear_init_scale', ge(0), "A non-negative float", float)
+    linear_init_sigma = hp('linear_init_sigma', ge(0), "A non-negative float", float)
+    linear_init_value = hp('linear_init_value', (), "A float value", float)
     factors_init_method = hp('factors_init_method', isin('normal', 'uniform', 'constant'),
-                             'Value "normal", "uniform" or "constant"')
-    factors_init_scale = hp('factors_init_scale', (ge(0), isnumber), "A non-negative float")
-    factors_init_sigma = hp('factors_init_sigma', (ge(0), isnumber), "A non-negative float")
-    factors_init_value = hp('factors_init_value', isnumber, "A float value")
+                             'Value "normal", "uniform" or "constant"', str)
+    factors_init_scale = hp('factors_init_scale', ge(0), "A non-negative float", float)
+    factors_init_sigma = hp('factors_init_sigma', ge(0), "A non-negative float", float)
+    factors_init_value = hp('factors_init_value', (), "A float value", float)
 
     def __init__(self, role, train_instance_count, train_instance_type,
                  num_factors, predictor_type,

src/sagemaker/amazon/hyperparameter.py (+5 -2)

@@ -16,7 +16,7 @@ class Hyperparameter(object):
     """An algorithm hyperparameter with optional validation. Implemented as a python
     descriptor object."""
 
-    def __init__(self, name, validate=lambda _: True, validation_message=""):
+    def __init__(self, name, validate=lambda _: True, validation_message="", data_type=str):
         """Args:
             name (str): The name of this hyperparameter
             validate (callable[object]->[bool]): A validation function or list of validation functions.
@@ -27,6 +27,7 @@ def __init__(self, name, validate=lambda _: True, validation_message=""):
         self.validation = validate
         self.validation_message = validation_message
         self.name = name
+        self.data_type = data_type
         try:
             iter(self.validation)
         except TypeError:
@@ -35,9 +36,10 @@ def __init__(self, name, validate=lambda _: True, validation_message=""):
     def validate(self, value):
         if value is None: # We allow assignment from None, but Nones are not sent to training.
             return
+
         for valid in self.validation:
             if not valid(value):
-                error_message = "Invalid hyperparameter value {}".format(value)
+                error_message = "Invalid hyperparameter value {} for {}".format(value, self.name)
                 if self.validation_message:
                     error_message = error_message + ". Expecting: " + self.validation_message
                 raise ValueError(error_message)
@@ -50,6 +52,7 @@ def __get__(self, obj, objtype):
 
     def __set__(self, obj, value):
         """Validate the supplied value and set this hyperparameter to value"""
+        value = None if value is None else self.data_type(value)
        self.validate(value)
         if '_hyperparameters' not in dir(obj):
             obj._hyperparameters = dict()
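Putting the pieces together, here is a condensed, runnable sketch of how the descriptor behaves after this change. It is simplified from the diff above (the iter()/TypeError handling for single callables is abbreviated to an isinstance check, and SomeEstimator is a made-up owner class for illustration only):

```python
class Hyperparameter(object):
    """Condensed sketch of the descriptor after this commit (not the full class)."""

    def __init__(self, name, validate=lambda _: True, validation_message="", data_type=str):
        self.name = name
        # the real __init__ uses iter()/TypeError to detect single callables
        self.validation = validate if isinstance(validate, (list, tuple)) else [validate]
        self.validation_message = validation_message
        self.data_type = data_type

    def validate(self, value):
        if value is None:  # None is allowed, but Nones are not sent to training
            return
        for valid in self.validation:
            if not valid(value):
                raise ValueError("Invalid hyperparameter value {} for {}".format(value, self.name))

    def __get__(self, obj, objtype):
        return obj._hyperparameters.get(self.name)

    def __set__(self, obj, value):
        # cast first, then validate: "10" becomes 10 before any range check runs
        value = None if value is None else self.data_type(value)
        self.validate(value)
        if '_hyperparameters' not in dir(obj):
            obj._hyperparameters = dict()
        obj._hyperparameters[self.name] = value


class SomeEstimator(object):  # hypothetical owner class, for illustration only
    epochs = Hyperparameter('epochs', lambda x: x > 0, "An integer greater than 0", int)


est = SomeEstimator()
est.epochs = "5"            # a string from the API is cast to int before validation
assert est.epochs == 5 and type(est.epochs) is int
```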

src/sagemaker/amazon/kmeans.py (+10 -10)

@@ -13,7 +13,7 @@
 from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
 from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
-from sagemaker.amazon.validation import gt, isin, isint, ge
+from sagemaker.amazon.validation import gt, isin, ge
 from sagemaker.predictor import RealTimePredictor
 from sagemaker.model import Model
 from sagemaker.session import Session
@@ -23,15 +23,15 @@ class KMeans(AmazonAlgorithmEstimatorBase):
 
     repo = 'kmeans:1'
 
-    k = hp('k', (gt(1), isint), 'An integer greater-than 1')
-    init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"')
-    max_iterations = hp('local_lloyd_max_iterations', (gt(0), isint), 'An integer greater-than 0')
-    tol = hp('local_lloyd_tol', (gt(0), isint), 'An integer greater-than 0')
-    num_trials = hp('local_lloyd_num_trials', (gt(0), isint), 'An integer greater-than 0')
-    local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"')
-    half_life_time_size = hp('half_life_time_size', (ge(0), isint), 'An integer greater-than-or-equal-to 0')
-    epochs = hp('epochs', (gt(0), isint), 'An integer greater-than 0')
-    center_factor = hp('extra_center_factor', (gt(0), isint), 'An integer greater-than 0')
+    k = hp('k', gt(1), 'An integer greater-than 1', int)
+    init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
+    max_iterations = hp('local_lloyd_max_iterations', gt(0), 'An integer greater-than 0', int)
+    tol = hp('local_lloyd_tol', gt(0), 'An integer greater-than 0', int)
+    num_trials = hp('local_lloyd_num_trials', gt(0), 'An integer greater-than 0', int)
+    local_init_method = hp('local_lloyd_init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str)
+    half_life_time_size = hp('half_life_time_size', ge(0), 'An integer greater-than-or-equal-to 0', int)
+    epochs = hp('epochs', gt(0), 'An integer greater-than 0', int)
+    center_factor = hp('extra_center_factor', gt(0), 'An integer greater-than 0', int)
 
     def __init__(self, role, train_instance_count, train_instance_type, k, init_method=None,
                  max_iterations=None, tol=None, num_trials=None, local_init_method=None,

src/sagemaker/amazon/linear_learner.py (+34 -33)

@@ -13,7 +13,7 @@
 from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
 from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
-from sagemaker.amazon.validation import isin, gt, lt, isint, isbool, isnumber
+from sagemaker.amazon.validation import isin, gt, lt
 from sagemaker.predictor import RealTimePredictor
 from sagemaker.model import Model
 from sagemaker.session import Session
@@ -27,40 +27,41 @@ class LinearLearner(AmazonAlgorithmEstimatorBase):
 
     binary_classifier_model_selection_criteria = hp('binary_classifier_model_selection_criteria',
                                                     isin('accuracy', 'f1', 'precision_at_target_recall',
-                                                         'recall_at_target_precision', 'cross_entropy_loss'))
-    target_recall = hp('target_recall', (gt(0), lt(1)), "A float in (0,1)")
-    target_precision = hp('target_precision', (gt(0), lt(1)), "A float in (0,1)")
-    positive_example_weight_mult = hp('positive_example_weight_mult', gt(0), "A float greater than 0")
-    epochs = hp('epochs', (gt(0), isint), "An integer greater-than 0")
+                                                         'recall_at_target_precision', 'cross_entropy_loss'),
+                                                    data_type=str)
+    target_recall = hp('target_recall', (gt(0), lt(1)), "A float in (0,1)", float)
+    target_precision = hp('target_precision', (gt(0), lt(1)), "A float in (0,1)", float)
+    positive_example_weight_mult = hp('positive_example_weight_mult', gt(0), "A float greater than 0", float)
+    epochs = hp('epochs', gt(0), "An integer greater-than 0", int)
     predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'),
-                        'One of "binary_classifier" or "regressor"')
-    use_bias = hp('use_bias', isbool, "Either True or False")
-    num_models = hp('num_models', (gt(0), isint), "An integer greater-than 0")
-    num_calibration_samples = hp('num_calibration_samples', (gt(0), isint), "An integer greater-than 0")
-    init_method = hp('init_method', isin('uniform', 'normal'), 'One of "uniform" or "normal"')
-    init_scale = hp('init_scale', (gt(-1), lt(1)), 'A float in (-1, 1)')
-    init_sigma = hp('init_sigma', (gt(0), lt(1)), 'A float in (0, 1)')
-    init_bias = hp('init_bias', isnumber, 'A number')
-    optimizer = hp('optimizer', isin('sgd', 'adam', 'auto'), 'One of "sgd", "adam" or "auto')
+                        'One of "binary_classifier" or "regressor"', str)
+    use_bias = hp('use_bias', (), "Either True or False", bool)
+    num_models = hp('num_models', gt(0), "An integer greater-than 0", int)
+    num_calibration_samples = hp('num_calibration_samples', gt(0), "An integer greater-than 0", int)
+    init_method = hp('init_method', isin('uniform', 'normal'), 'One of "uniform" or "normal"', str)
+    init_scale = hp('init_scale', (gt(-1), lt(1)), 'A float in (-1, 1)', float)
+    init_sigma = hp('init_sigma', (gt(0), lt(1)), 'A float in (0, 1)', float)
+    init_bias = hp('init_bias', (), 'A number', float)
+    optimizer = hp('optimizer', isin('sgd', 'adam', 'auto'), 'One of "sgd", "adam" or "auto', str)
     loss = hp('loss', isin('logistic', 'squared_loss', 'absolute_loss', 'auto'),
-              '"logistic", "squared_loss", "absolute_loss" or"auto"')
-    wd = hp('wd', (gt(0), lt(1)), 'A float in (0,1)')
-    l1 = hp('l1', (gt(0), lt(1)), 'A float in (0,1)')
-    momentum = hp('momentum', (gt(0), lt(1)), 'A float in (0,1)')
-    learning_rate = hp('learning_rate', (gt(0), lt(1)), 'A float in (0,1)')
-    beta_1 = hp('beta_1', (gt(0), lt(1)), 'A float in (0,1)')
-    beta_2 = hp('beta_1', (gt(0), lt(1)), 'A float in (0,1)')
-    bias_lr_mult = hp('bias_lr_mult', gt(0), 'A float greater-than 0')
-    bias_wd_mult = hp('bias_wd_mult', gt(0), 'A float greater-than 0')
-    use_lr_scheduler = hp('use_lr_scheduler', isbool, 'A boolean')
-    lr_scheduler_step = hp('lr_scheduler_step', (gt(0), isint), 'An integer greater-than 0')
-    lr_scheduler_factor = hp('lr_scheduler_factor', (gt(0), lt(1)), 'A float in (0,1)')
-    lr_scheduler_minimum_lr = hp('lr_scheduler_minimum_lr', gt(0), 'A float greater-than 0')
-    normalize_data = hp('normalize_data', isbool, 'A boolean')
-    normalize_label = hp('normalize_label', isbool, 'A boolean')
-    unbias_data = hp('unbias_data', isbool, 'A boolean')
-    unbias_label = hp('unbias_label', isbool, 'A boolean')
-    num_point_for_scalar = hp('num_point_for_scalar', (isint, gt(0)), 'An integer greater-than 0')
+              '"logistic", "squared_loss", "absolute_loss" or"auto"', str)
+    wd = hp('wd', (gt(0), lt(1)), 'A float in (0,1)', float)
+    l1 = hp('l1', (gt(0), lt(1)), 'A float in (0,1)', float)
+    momentum = hp('momentum', (gt(0), lt(1)), 'A float in (0,1)', float)
+    learning_rate = hp('learning_rate', (gt(0), lt(1)), 'A float in (0,1)', float)
+    beta_1 = hp('beta_1', (gt(0), lt(1)), 'A float in (0,1)', float)
+    beta_2 = hp('beta_1', (gt(0), lt(1)), 'A float in (0,1)', float)
+    bias_lr_mult = hp('bias_lr_mult', gt(0), 'A float greater-than 0', float)
+    bias_wd_mult = hp('bias_wd_mult', gt(0), 'A float greater-than 0', float)
+    use_lr_scheduler = hp('use_lr_scheduler', (), 'A boolean', bool)
+    lr_scheduler_step = hp('lr_scheduler_step', gt(0), 'An integer greater-than 0', int)
+    lr_scheduler_factor = hp('lr_scheduler_factor', (gt(0), lt(1)), 'A float in (0,1)', float)
+    lr_scheduler_minimum_lr = hp('lr_scheduler_minimum_lr', gt(0), 'A float greater-than 0', float)
+    normalize_data = hp('normalize_data', (), 'A boolean', bool)
+    normalize_label = hp('normalize_label', (), 'A boolean', bool)
+    unbias_data = hp('unbias_data', (), 'A boolean', bool)
+    unbias_label = hp('unbias_label', (), 'A boolean', bool)
+    num_point_for_scalar = hp('num_point_for_scalar', gt(0), 'An integer greater-than 0', int)
 
     def __init__(self, role, train_instance_count, train_instance_type, predictor_type='binary_classifier',
                  binary_classifier_model_selection_criteria=None, target_recall=None, target_precision=None,
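One plain-Python detail worth keeping in mind for the bool-typed hyperparameters above (use_bias, normalize_data, and friends): bool() treats any non-empty string as truthy, unlike the parsing casts int() and float(). This is standard Python behavior, not SDK code:

```python
bool("True")      # True
bool("False")     # True -- non-empty string, so still truthy
bool("")          # False

# numeric casts parse the string's content instead:
int("0")          # 0
float("0.1")      # 0.1
```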

src/sagemaker/amazon/pca.py (+6 -7)

@@ -24,14 +24,13 @@ class PCA(AmazonAlgorithmEstimatorBase):
 
     DEFAULT_MINI_BATCH_SIZE = 500
 
-    num_components = hp(name='num_components', validate=lambda x: x > 0 and isinstance(x, int),
-                        validation_message='Value must be an integer greater than zero')
+    num_components = hp(name='num_components', validate=lambda x: x > 0,
+                        validation_message='Value must be an integer greater than zero', data_type=int)
     algorithm_mode = hp(name='algorithm_mode', validate=lambda x: x in ['regular', 'stable', 'randomized'],
-                        validation_message='Value must be one of "regular", "stable", "randomized"')
-    subtract_mean = hp(name='subtract_mean', validate=lambda x: isinstance(x, bool),
-                       validation_message='Value must be a boolean')
-    extra_components = hp(name='extra_components', validate=lambda x: x >= 0 and isinstance(x, int),
-                          validation_message="Value must be an integer greater than or equal to 0")
+                        validation_message='Value must be one of "regular", "stable", "randomized"', data_type=str)
+    subtract_mean = hp(name='subtract_mean', validation_message='Value must be a boolean', data_type=bool)
+    extra_components = hp(name='extra_components', validate=lambda x: x >= 0,
+                          validation_message="Value must be an integer greater than or equal to 0", data_type=int)
 
     def __init__(self, role, train_instance_count, train_instance_type, num_components,
                  algorithm_mode=None, subtract_mean=None, extra_components=None, **kwargs):

src/sagemaker/amazon/validation.py (+0 -6)

@@ -10,7 +10,6 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-import numbers
 
 
 def gt(minimum):
@@ -41,8 +40,3 @@ def istype(expected):
     def validate(value):
         return isinstance(value, expected)
     return validate
-
-
-isint = istype(int)
-isbool = istype(bool)
-isnumber = istype(numbers.Number) # noqa
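For context, the validators this file still exports are all small closure factories in the style of gt and istype shown above. A minimal sketch, with ge's and isin's signatures assumed from their call sites in the diffs:

```python
def gt(minimum):
    def validate(value):
        return value > minimum
    return validate


def ge(minimum):
    def validate(value):
        return value >= minimum
    return validate


def isin(*expected):
    def validate(value):
        return value in expected
    return validate


# usage, mirroring the hyperparameter declarations in the diffs above
assert gt(0)(5) and not gt(0)(-1)
assert isin('random', 'kmeans++')('kmeans++')
```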

tests/unit/test_hyperparameter.py (+19 -2)

@@ -16,9 +16,9 @@
 
 class Test(object):
 
-    blank = Hyperparameter(name="some-name")
+    blank = Hyperparameter(name="some-name", data_type=int)
     elizabeth = Hyperparameter(name='elizabeth')
-    validated = Hyperparameter(name="validated", validate=lambda value: value > 55)
+    validated = Hyperparameter(name="validated", validate=lambda value: value > 55, data_type=int)
 
 
 def test_blank_access():
@@ -55,3 +55,20 @@ def test_validated():
     x.validated = 66
     with pytest.raises(ValueError):
         x.validated = 23
+
+
+def test_data_type():
+    x = Test()
+    x.validated = 66
+    assert type(x.validated) == Test.__dict__["validated"].data_type
+
+
+def test_from_string():
+    x = Test()
+    value = 65
+
+    x.validated = value
+    from_api = str(value)
+
+    x.validated = from_api
+    assert x.validated == value
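test_from_string exercises the case the commit message calls out: a value that round-trips through the API's string form comes back as the declared type on assignment. Assuming a standard pytest setup, the two new tests can be run on their own, for example:

```python
import pytest

# "-k" narrows the run to the two tests added in this commit
pytest.main(["tests/unit/test_hyperparameter.py", "-k", "data_type or from_string"])
```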
