Skip to content

Commit 6dc7905

Browse files
authored
Refactor the rest of the 1P estimators (aws#24)
1 parent f29cf47 commit 6dc7905

12 files changed

+58
-61
lines changed

src/sagemaker/amazon/kmeans.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ def create_model(self):
133133
s3 model data produced by this Estimator."""
134134
return KMeansModel(self.model_data, self.role, self.sagemaker_session)
135135

136-
def fit(self, records, mini_batch_size=5000, **kwargs):
137-
super(KMeans, self).fit(records, mini_batch_size, **kwargs)
136+
def prepare_for_training(self, records, mini_batch_size=5000, job_name=None):
137+
super(KMeans, self).prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
138138

139139
def hyperparameters(self):
140140
"""Return the SageMaker hyperparameters for training this KMeans Estimator"""

src/sagemaker/amazon/lda.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,12 @@ def create_model(self):
9393

9494
return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
9595

96-
def fit(self, records, mini_batch_size, **kwargs):
96+
def prepare_for_training(self, records, mini_batch_size, job_name=None):
9797
# mini_batch_size is required, prevent explicit calls with None
9898
if mini_batch_size is None:
9999
raise ValueError("mini_batch_size must be set")
100-
super(LDA, self).fit(records, mini_batch_size, **kwargs)
100+
101+
super(LDA, self).prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
101102

102103

103104
class LDAPredictor(RealTimePredictor):

src/sagemaker/amazon/linear_learner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,12 @@ def create_model(self):
228228

229229
return LinearLearnerModel(self.model_data, self.role, self.sagemaker_session)
230230

231-
def fit(self, records, mini_batch_size=None, **kwargs):
231+
def prepare_for_training(self, records, mini_batch_size=None, job_name=None):
232232
# mini_batch_size can't be greater than number of records or training job fails
233233
default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE,
234234
max(1, int(records.num_records / self.train_instance_count)))
235235
use_mini_batch_size = mini_batch_size or default_mini_batch_size
236-
super(LinearLearner, self).fit(records, use_mini_batch_size, **kwargs)
236+
super(LinearLearner, self).prepare_for_training(records, mini_batch_size=use_mini_batch_size, job_name=job_name)
237237

238238

239239
class LinearLearnerPredictor(RealTimePredictor):

src/sagemaker/amazon/ntm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def create_model(self):
113113

114114
return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
115115

116-
def fit(self, records, mini_batch_size=None, **kwargs):
116+
def prepare_for_training(self, records, mini_batch_size, job_name=None):
117117
if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 10000):
118118
raise ValueError("mini_batch_size must be in [1, 10000]")
119-
super(NTM, self).fit(records, mini_batch_size, **kwargs)
119+
super(NTM, self).prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
120120

121121

122122
class NTMPredictor(RealTimePredictor):

src/sagemaker/amazon/randomcutforest.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,11 @@ def create_model(self):
8787

8888
return RandomCutForestModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
8989

90-
def fit(self, records, mini_batch_size=None, **kwargs):
91-
if mini_batch_size is None:
92-
mini_batch_size = RandomCutForest.MINI_BATCH_SIZE
93-
elif mini_batch_size != RandomCutForest.MINI_BATCH_SIZE:
90+
def prepare_for_training(self, records, mini_batch_size=MINI_BATCH_SIZE, job_name=None):
91+
if mini_batch_size != self.MINI_BATCH_SIZE:
9492
raise ValueError("Random Cut Forest uses a fixed mini_batch_size of {}"
95-
.format(RandomCutForest.MINI_BATCH_SIZE))
96-
super(RandomCutForest, self).fit(records, mini_batch_size, **kwargs)
93+
.format(self.MINI_BATCH_SIZE))
94+
super(RandomCutForest, self).prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
9795

9896

9997
class RandomCutForestPredictor(RealTimePredictor):

tests/unit/test_fm.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -210,31 +210,31 @@ def test_call_fit(base_fit, sagemaker_session):
210210
assert base_fit.call_args[0][1] == MINI_BATCH_SIZE
211211

212212

213-
def test_call_fit_none_mini_batch_size(sagemaker_session):
213+
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
214214
fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
215215

216216
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
217217
channel='train')
218-
fm.fit(data)
218+
fm.prepare_for_training(data)
219219

220220

221-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
221+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
222222
fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
223223

224224
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
225225
channel='train')
226226

227227
with pytest.raises((TypeError, ValueError)):
228-
fm.fit(data, 'some')
228+
fm.prepare_for_training(data, 'some')
229229

230230

231-
def test_call_fit_wrong_value_mini_batch_size(sagemaker_session):
231+
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
232232
fm = FactorizationMachines(base_job_name='fm', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
233233

234234
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
235235
channel='train')
236236
with pytest.raises(ValueError):
237-
fm.fit(data, 0)
237+
fm.prepare_for_training(data, 0)
238238

239239

240240
def test_model_image(sagemaker_session):

tests/unit/test_kmeans.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -175,31 +175,33 @@ def test_call_fit(base_fit, sagemaker_session):
175175
assert base_fit.call_args[0][1] == MINI_BATCH_SIZE
176176

177177

178-
def test_call_fit_none_mini_batch_size(sagemaker_session):
178+
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
179179
kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
180180

181181
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
182182
channel='train')
183-
kmeans.fit(data)
183+
kmeans.prepare_for_training(data)
184184

185+
assert kmeans.mini_batch_size == 5000
185186

186-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
187+
188+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
187189
kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
188190

189191
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
190192
channel='train')
191193

192194
with pytest.raises((TypeError, ValueError)):
193-
kmeans.fit(data, 'some')
195+
kmeans.prepare_for_training(data, 'some')
194196

195197

196-
def test_call_fit_wrong_value_mini_batch_size(sagemaker_session):
198+
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
197199
kmeans = KMeans(base_job_name='kmeans', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
198200

199201
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
200202
channel='train')
201203
with pytest.raises(ValueError):
202-
kmeans.fit(data, 0)
204+
kmeans.prepare_for_training(data, 0)
203205

204206

205207
def test_model_image(sagemaker_session):

tests/unit/test_lda.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -147,32 +147,32 @@ def test_call_fit(base_fit, sagemaker_session):
147147
assert base_fit.call_args[0][1] == MINI_BATCH_SZIE
148148

149149

150-
def test_call_fit_none_mini_batch_size(sagemaker_session):
150+
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
151151
lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
152152

153153
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
154154
channel='train')
155155
with pytest.raises(ValueError):
156-
lda.fit(data, None)
156+
lda.prepare_for_training(data, None)
157157

158158

159-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
159+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
160160
lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
161161

162162
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
163163
channel='train')
164164

165165
with pytest.raises(ValueError):
166-
lda.fit(data, 'some')
166+
lda.prepare_for_training(data, 'some')
167167

168168

169-
def test_call_fit_wrong_value_mini_batch_size(sagemaker_session):
169+
def test_prepare_for_training_wrong_value_mini_batch_size(sagemaker_session):
170170
lda = LDA(base_job_name='lda', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
171171

172172
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
173173
channel='train')
174174
with pytest.raises(ValueError):
175-
lda.fit(data, 0)
175+
lda.prepare_for_training(data, 0)
176176

177177

178178
def test_model_image(sagemaker_session):

tests/unit/test_linear_learner.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -218,35 +218,27 @@ def test_optional_hyper_parameters_value(sagemaker_session, optional_hyper_param
218218
DEFAULT_MINI_BATCH_SIZE = 1000
219219

220220

221-
@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')
222-
def test_call_fit_calculate_batch_size_1(base_fit, sagemaker_session):
221+
def test_prepare_for_training_calculate_batch_size_1(sagemaker_session):
223222
lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
224223

225224
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train')
226225

227-
lr.fit(data)
226+
lr.prepare_for_training(data)
228227

229-
base_fit.assert_called_once()
230-
assert len(base_fit.call_args[0]) == 2
231-
assert base_fit.call_args[0][0] == data
232-
assert base_fit.call_args[0][1] == 1
228+
assert lr.mini_batch_size == 1
233229

234230

235-
@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')
236-
def test_call_fit_calculate_batch_size_2(base_fit, sagemaker_session):
231+
def test_prepare_for_training_calculate_batch_size_2(sagemaker_session):
237232
lr = LinearLearner(base_job_name='lr', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
238233

239234
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX),
240235
num_records=10000,
241236
feature_dim=FEATURE_DIM,
242237
channel='train')
243238

244-
lr.fit(data)
239+
lr.prepare_for_training(data)
245240

246-
base_fit.assert_called_once()
247-
assert len(base_fit.call_args[0]) == 2
248-
assert base_fit.call_args[0][0] == data
249-
assert base_fit.call_args[0][1] == DEFAULT_MINI_BATCH_SIZE
241+
assert lr.mini_batch_size == DEFAULT_MINI_BATCH_SIZE
250242

251243

252244
@patch('sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit')

tests/unit/test_ntm.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -193,32 +193,32 @@ def test_call_fit_none_mini_batch_size(sagemaker_session):
193193
ntm.fit(data)
194194

195195

196-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
196+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
197197
ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
198198

199199
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
200200
channel='train')
201201

202202
with pytest.raises((TypeError, ValueError)):
203-
ntm.fit(data, "some")
203+
ntm.prepare_for_training(data, "some")
204204

205205

206-
def test_call_fit_wrong_value_lower_mini_batch_size(sagemaker_session):
206+
def test_prepare_for_training_wrong_value_lower_mini_batch_size(sagemaker_session):
207207
ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
208208

209209
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
210210
channel='train')
211211
with pytest.raises(ValueError):
212-
ntm.fit(data, 0)
212+
ntm.prepare_for_training(data, 0)
213213

214214

215-
def test_call_fit_wrong_value_upper_mini_batch_size(sagemaker_session):
215+
def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session):
216216
ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
217217

218218
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
219219
channel='train')
220220
with pytest.raises(ValueError):
221-
ntm.fit(data, 10001)
221+
ntm.prepare_for_training(data, 10001)
222222

223223

224224
def test_model_image(sagemaker_session):

tests/unit/test_pca.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,17 @@ def test_call_fit(base_fit, sagemaker_session):
143143
assert base_fit.call_args[0][1] == MINI_BATCH_SIZE
144144

145145

146-
def test_call_fit_none_mini_batch_size(sagemaker_session):
146+
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
147147
pca = PCA(base_job_name='pca', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
148148

149149
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
150150
channel='train')
151-
pca.fit(data)
151+
pca.prepare_for_training(data)
152152

153+
assert pca.mini_batch_size == 1
153154

154-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
155+
156+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
155157
pca = PCA(base_job_name='pca', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
156158

157159
data = RecordSet('s3://{}/{}'.format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,

tests/unit/test_randomcutforest.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -141,35 +141,37 @@ def test_call_fit(base_fit, sagemaker_session):
141141
assert base_fit.call_args[0][1] == MINI_BATCH_SIZE
142142

143143

144-
def test_call_fit_none_mini_batch_size(sagemaker_session):
144+
def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
145145
randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session,
146146
**ALL_REQ_ARGS)
147147

148148
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
149149
channel='train')
150-
randomcutforest.fit(data)
150+
randomcutforest.prepare_for_training(data)
151151

152+
assert randomcutforest.mini_batch_size == MINI_BATCH_SIZE
152153

153-
def test_call_fit_wrong_type_mini_batch_size(sagemaker_session):
154+
155+
def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
154156
randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session,
155157
**ALL_REQ_ARGS)
156158

157159
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM,
158160
channel='train')
159161

160162
with pytest.raises((TypeError, ValueError)):
161-
randomcutforest.fit(data, 1234)
163+
randomcutforest.prepare_for_training(data, 1234)
162164

163165

164-
def test_call_fit_feature_dim_greater_than_max_allowed(sagemaker_session):
166+
def test_prepare_for_training_feature_dim_greater_than_max_allowed(sagemaker_session):
165167
randomcutforest = RandomCutForest(base_job_name="randomcutforest", sagemaker_session=sagemaker_session,
166168
**ALL_REQ_ARGS)
167169

168170
data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=MAX_FEATURE_DIM + 1,
169171
channel='train')
170172

171173
with pytest.raises((TypeError, ValueError)):
172-
randomcutforest.fit(data)
174+
randomcutforest.prepare_for_training(data)
173175

174176

175177
def test_model_image(sagemaker_session):

0 commit comments

Comments
 (0)