@@ -141,35 +141,37 @@ def test_call_fit(base_fit, sagemaker_session):
141
141
assert base_fit .call_args [0 ][1 ] == MINI_BATCH_SIZE
142
142
143
143
144
- def test_call_fit_none_mini_batch_size (sagemaker_session ):
144
+ def test_prepare_for_training_no_mini_batch_size (sagemaker_session ):
145
145
randomcutforest = RandomCutForest (base_job_name = "randomcutforest" , sagemaker_session = sagemaker_session ,
146
146
** ALL_REQ_ARGS )
147
147
148
148
data = RecordSet ("s3://{}/{}" .format (BUCKET_NAME , PREFIX ), num_records = 1 , feature_dim = FEATURE_DIM ,
149
149
channel = 'train' )
150
- randomcutforest .fit (data )
150
+ randomcutforest .prepare_for_training (data )
151
151
152
+ assert randomcutforest .mini_batch_size == MINI_BATCH_SIZE
152
153
153
- def test_call_fit_wrong_type_mini_batch_size (sagemaker_session ):
154
+
155
+ def test_prepare_for_training_wrong_type_mini_batch_size (sagemaker_session ):
154
156
randomcutforest = RandomCutForest (base_job_name = "randomcutforest" , sagemaker_session = sagemaker_session ,
155
157
** ALL_REQ_ARGS )
156
158
157
159
data = RecordSet ("s3://{}/{}" .format (BUCKET_NAME , PREFIX ), num_records = 1 , feature_dim = FEATURE_DIM ,
158
160
channel = 'train' )
159
161
160
162
with pytest .raises ((TypeError , ValueError )):
161
- randomcutforest .fit (data , 1234 )
163
+ randomcutforest .prepare_for_training (data , 1234 )
162
164
163
165
164
- def test_call_fit_feature_dim_greater_than_max_allowed (sagemaker_session ):
166
+ def test_prepare_for_training_feature_dim_greater_than_max_allowed (sagemaker_session ):
165
167
randomcutforest = RandomCutForest (base_job_name = "randomcutforest" , sagemaker_session = sagemaker_session ,
166
168
** ALL_REQ_ARGS )
167
169
168
170
data = RecordSet ("s3://{}/{}" .format (BUCKET_NAME , PREFIX ), num_records = 1 , feature_dim = MAX_FEATURE_DIM + 1 ,
169
171
channel = 'train' )
170
172
171
173
with pytest .raises ((TypeError , ValueError )):
172
- randomcutforest .fit (data )
174
+ randomcutforest .prepare_for_training (data )
173
175
174
176
175
177
def test_model_image (sagemaker_session ):
0 commit comments