@@ -197,7 +197,8 @@ def test_trainer_name():
197
197
'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
198
198
}
199
199
session = create_sagemaker_session (describe_training_result )
200
- trainer = TrainingJobAnalytics ("my-training-job" , ["metric" ], sagemaker_session = session )
200
+ trainer = TrainingJobAnalytics (training_job_name = "my-training-job" , metric_names = ["metric" ],
201
+ sagemaker_session = session )
201
202
assert trainer .name == "my-training-job"
202
203
assert str (trainer ).find ("my-training-job" ) != - 1
203
204
@@ -231,7 +232,8 @@ def _metric_stats_results():
231
232
def test_trainer_dataframe ():
232
233
session = create_sagemaker_session (describe_training_result = _describe_training_result (),
233
234
metric_stats_results = _metric_stats_results ())
234
- trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
235
+ trainer = TrainingJobAnalytics (training_job_name = "my-training-job" , metric_names = ["train:acc" ],
236
+ sagemaker_session = session )
235
237
236
238
df = trainer .dataframe ()
237
239
assert df is not None
@@ -245,3 +247,20 @@ def test_trainer_dataframe():
245
247
trainer .export_csv (tmp_name )
246
248
assert os .path .isfile (tmp_name )
247
249
os .unlink (tmp_name )
250
+
251
+
252
+ def test_start_time_end_time_and_period_specified ():
253
+ describe_training_result = {
254
+ 'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
255
+ 'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
256
+ }
257
+ session = create_sagemaker_session (describe_training_result )
258
+ start_time = datetime .datetime (2018 , 5 , 16 , 1 , 3 , 4 )
259
+ end_time = datetime .datetime (2018 , 5 , 16 , 5 , 1 , 1 )
260
+ period = 300
261
+ trainer = TrainingJobAnalytics (training_job_name = "my-training-job" , metric_names = ["metric" ],
262
+ sagemaker_session = session , start_time = start_time , end_time = end_time , period = period )
263
+
264
+ assert trainer ._time_interval ['start_time' ] == start_time
265
+ assert trainer ._time_interval ['end_time' ] == end_time
266
+ assert trainer ._period == period
0 commit comments