Skip to content

Commit 81f90a5

Browse files
author
Vikram Rajasekaran
committed
Added scaling_type to Integer and Continuous ranges
1 parent a504db4 commit 81f90a5

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/sagemaker/parameter.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ParameterRange(object):
2525

2626
__all_types__ = ('Continuous', 'Categorical', 'Integer')
2727

28-
def __init__(self, min_value, max_value):
28+
def __init__(self, min_value, max_value, scaling_type=None):
2929
"""Initialize a parameter range.
3030
3131
Args:
@@ -34,6 +34,7 @@ def __init__(self, min_value, max_value):
3434
"""
3535
self.min_value = min_value
3636
self.max_value = max_value
37+
self.scaling_type = scaling_type
3738

3839
def is_valid(self, value):
3940
"""Determine if a value is valid within this ParameterRange.
@@ -60,9 +61,16 @@ def as_tuning_range(self, name):
6061
Returns:
6162
dict[str, str]: A dictionary that contains the name and values of the hyperparameter.
6263
"""
63-
return {'Name': name,
64-
'MinValue': to_str(self.min_value),
65-
'MaxValue': to_str(self.max_value)}
64+
tuning_range = {
65+
'Name': name,
66+
'MinValue': to_str(self.min_value),
67+
'MaxValue': to_str(self.max_value)
68+
}
69+
70+
if self.scaling_type is not None:
71+
tuning_range['ScalingType'] = self.scaling_type
72+
73+
return tuning_range
6674

6775

6876
class ContinuousParameter(ParameterRange):

tests/unit/test_tuner.py

+12
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,12 @@ def test_continuous_parameter_ranges():
637637
assert ranges['MaxValue'] == '0.01'
638638

639639

640+
def test_continuous_parameter_scaling_type():
641+
cont_param = ContinuousParameter(0.1, 2, scaling_type='ReverseLogarithmic')
642+
cont_range = cont_param.as_tuning_range('range')
643+
assert cont_range['ScalingType'] == 'ReverseLogarithmic'
644+
645+
640646
def test_integer_parameter():
641647
int_param = IntegerParameter(1, 2)
642648
assert isinstance(int_param, ParameterRange)
@@ -652,6 +658,12 @@ def test_integer_parameter_ranges():
652658
assert ranges['MaxValue'] == '2'
653659

654660

661+
def test_integer_parameter_scaling_type():
662+
int_param = IntegerParameter(2, 3, scaling_type='Auto')
663+
int_range = int_param.as_tuning_range('range')
664+
assert int_range['ScalingType'] == 'Auto'
665+
666+
655667
def test_categorical_parameter_list():
656668
cat_param = CategoricalParameter(['a', 'z'])
657669
assert isinstance(cat_param, ParameterRange)

0 commit comments

Comments
 (0)