@@ -13,7 +13,9 @@
 from __future__ import absolute_import
 
 import inspect
+import json
 
+from sagemaker.estimator import Framework
 from sagemaker.job import _Job
 from sagemaker.utils import base_name_from_image, name_from_base
 
@@ -51,6 +53,9 @@ def as_tuning_range(self, name):
         return {'Name': name,
                 'Values': self.values}
 
+    def as_json_range(self, name):
+        return {'Name': name, 'Values': [json.dumps(v) for v in self.values]}
+
 
 
 class IntegerParameter(_ParameterRange):
     __name__ = 'Integer'
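
As a sketch (not in this diff): the new as_json_range() differs from as_tuning_range() only in JSON-encoding each value; the parameter name and values below are illustrative.

    # Assuming a CategoricalParameter built from ['relu', 'tanh']:
    param.as_tuning_range('activation')
    # => {'Name': 'activation', 'Values': ['relu', 'tanh']}
    param.as_json_range('activation')
    # => {'Name': 'activation', 'Values': ['"relu"', '"tanh"']}  (json.dumps quotes each value)
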
@@ -60,31 +65,58 @@ def __init__(self, min_value, max_value):
 
 
 class HyperparameterTuner(object):
-    __objectives__ = ['Minimize', 'Maximize']
+    SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
+    SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
 
     def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions, strategy='Bayesian',
                  objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, base_tuning_job_name=None):
-        if objective_type not in HyperparameterTuner.__objectives__:
-            raise ValueError("Unsupported 'objective' values")
+        self._hyperparameter_ranges = hyperparameter_ranges
+        if self._hyperparameter_ranges is None or len(self._hyperparameter_ranges) == 0:
+            raise ValueError('Need to specify hyperparameter ranges')
 
         self.estimator = estimator
         self.objective_metric_name = objective_metric_name
-        self._hyperparameter_ranges = hyperparameter_ranges
+        self.metric_definitions = metric_definitions
+
         self.strategy = strategy
         self.objective_type = objective_type
+
         self.max_jobs = max_jobs
         self.max_parallel_jobs = max_parallel_jobs
         self.tuning_job_name = base_tuning_job_name
         self.metric_definitions = metric_definitions
         self.latest_tuning_job = None
         self._validate_parameter_ranges()
 
-    def fit(self, inputs):
-        """Create tuning job
+    def prepare_for_training(self):
+        # TODO: Change this so that it can handle unicode in Python 2
+        self.static_hyperparameters = {str(k): str(v) for (k, v) in self.estimator.hyperparameters().items()}
+        for hyperparameter_name in self._hyperparameter_ranges.keys():
+            self.static_hyperparameters.pop(hyperparameter_name, None)
+
+        # For attach() to know what estimator to use
+        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
+        self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
+
+    def fit(self, inputs, job_name=None, **kwargs):
+        """Start a hyperparameter tuning job.
 
         Args:
-            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            job_name (str): Job name
+            **kwargs: Other arguments
         """
+        # TODO: I think I have to move RecordSet to its own file
+        from sagemaker.amazon.amazon_estimator import RecordSet
+
+        # 1P estimators require a RecordSet object
+        if isinstance(inputs, RecordSet):
+            self.estimator.prepare_for_training(inputs, **kwargs)
+            inputs = inputs.data_channel()
+        else:
+            self.estimator.prepare_for_training(**kwargs)
+
+        self.prepare_for_training()
         self.latest_tuning_job = _TuningJob.start_new(self, inputs)
 
     def stop_tuning_job(self):
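
As a hedged usage sketch of the new fit() flow (not in this diff): the estimator, metric name, regex, and hyperparameter names below are assumed placeholders, and the CategoricalParameter constructor is assumed to take a list of values.

    from sagemaker.tuner import HyperparameterTuner, IntegerParameter, CategoricalParameter

    # my_estimator is an already-configured SageMaker estimator (placeholder).
    tuner = HyperparameterTuner(estimator=my_estimator,
                                objective_metric_name='validation:accuracy',
                                hyperparameter_ranges={'mini_batch_size': IntegerParameter(32, 512),
                                                       'optimizer': CategoricalParameter(['sgd', 'adam'])},
                                metric_definitions=[{'Name': 'validation:accuracy',
                                                     'Regex': 'accuracy=([0-9\\.]+)'}],
                                max_jobs=10, max_parallel_jobs=2)
    tuner.fit(training_inputs)  # for a 1P algorithm, pass a RecordSet instead
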
@@ -101,15 +133,20 @@ def hyperparameter_ranges(self):
         """Return collections of ``ParameterRanges``
 
         Returns:
-            dict: ParameterRanges suitable for tuning job.
+            dict: ParameterRanges suitable for a hyperparameter tuning job.
         """
         hyperparameter_ranges = dict()
         for range_type in _ParameterRange.__all_types__:
-            parameter_range = []
+            parameter_ranges = []
             for parameter_name, parameter in self._hyperparameter_ranges.items():
                 if parameter is not None and parameter.__name__ == range_type:
-                    parameter_range.append(parameter.as_tuning_range(parameter_name))
-            hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_range
+                    # Categorical parameters need to be serialized as JSON for our framework containers
+                    if isinstance(parameter, CategoricalParameter) and isinstance(self.estimator, Framework):
+                        tuning_range = parameter.as_json_range(parameter_name)
+                    else:
+                        tuning_range = parameter.as_tuning_range(parameter_name)
+                    parameter_ranges.append(tuning_range)
+            hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_ranges
         return hyperparameter_ranges
 
     def _validate_parameter_ranges(self):
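
For reference (not in this diff): the loop above keys the result by each range type plus 'ParameterRanges'. A sketch of the shape, assuming _ParameterRange.__all_types__ contains 'Integer', 'Continuous', and 'Categorical' (only 'Integer' is confirmed by this diff) and using the illustrative 'optimizer' parameter from earlier on a Framework estimator:

    tuner.hyperparameter_ranges()
    # => {'IntegerParameterRanges': [...],
    #     'ContinuousParameterRanges': [...],
    #     'CategoricalParameterRanges': [{'Name': 'optimizer', 'Values': ['"sgd"', '"adam"']}]}
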
@@ -138,43 +175,38 @@ def _validate_parameter_ranges(self):
 
 
 class _TuningJob(_Job):
-    SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
-    SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
-
     def __init__(self, sagemaker_session, tuning_job_name):
         super(_TuningJob, self).__init__(sagemaker_session, tuning_job_name)
 
     @classmethod
     def start_new(cls, tuner, inputs):
-        """Create a new Amazon SageMaker tuning job from the HyperparameterTuner.
+        """Create a new Amazon SageMaker hyperparameter tuning job from the HyperparameterTuner.
 
         Args:
-            tuner (sagemaker.tuner.HyperparameterTuner): Tuner object created by the user.
-            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+            tuner (sagemaker.tuner.HyperparameterTuner): HyperparameterTuner object created by the user.
+            inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
 
         Returns:
             sagemaker.tuner._TuningJob: Constructed object that captures all information about the started job.
         """
         config = _Job._load_config(inputs, tuner.estimator)
 
-        static_hyperparameters = {str(k): str(v) for (k, v) in tuner.estimator.hyperparameters().items()}
-        for hyperparameter_name in tuner._hyperparameter_ranges.keys():
-            static_hyperparameters.pop(hyperparameter_name, None)
-
-        static_hyperparameters[cls.SAGEMAKER_ESTIMATOR_CLASS_NAME] = tuner.estimator.__class__.__name__
-        static_hyperparameters[cls.SAGEMAKER_ESTIMATOR_MODULE] = tuner.estimator.__module__
-
         base_name = tuner.estimator.base_job_name or base_name_from_image(tuner.estimator.train_image())
         tuning_job_name = name_from_base(base_name)
 
+        # TODO: Update name generation so that the base name isn't limited to so few characters
+        if len(tuning_job_name) > 32:
+            raise ValueError('Tuning job name too long - must be 32 characters or fewer: {}'.format(tuning_job_name))
+
         tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
-                                               objective=tuner.objective_type, metric_name=tuner.objective_metric_name,
+                                               objective_type=tuner.objective_type,
+                                               objective_metric_name=tuner.objective_metric_name,
                                                max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
                                                parameter_ranges=tuner.hyperparameter_ranges(),
-                                               static_hp=static_hyperparameters,
+                                               static_hyperparameters=tuner.static_hyperparameters,
                                                image=tuner.estimator.train_image(),
                                                input_mode=tuner.estimator.input_mode,
-                                               metric_definitions=tuner.estimator.metric_definitions,
+                                               metric_definitions=tuner.metric_definitions,
                                                role=(config['role']), input_config=(config['input_config']),
                                                output_config=(config['output_config']),
                                                resource_config=(config['resource_config']),
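
A design note on the SAGEMAKER_ESTIMATOR_CLASS_NAME and SAGEMAKER_ESTIMATOR_MODULE markers moved onto HyperparameterTuner: per the "For attach() to know what estimator to use" comment, they let a later attach() re-create the right estimator class from a finished job's static hyperparameters. A hedged sketch of that lookup (attach() itself is not in this diff, and the recovery logic is assumed):

    import importlib

    # Assumed recovery logic: read the markers back out of the job's static hyperparameters.
    module = importlib.import_module(static_hyperparameters['sagemaker_estimator_module'])
    estimator_cls = getattr(module, static_hyperparameters['sagemaker_estimator_class_name'])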