Skip to content

Commit bff4112

Browse files
committed
Add test cases for transformer logs
1 parent 790fe8f commit bff4112

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

src/sagemaker/session.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
12251225
billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
12261226
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
12271227

1228-
def logs_for_transform_job(self, job_name, wait=False, poll=10):
1228+
def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning
12291229
"""Display the logs for a given transform job, optionally tailing them until the
12301230
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
12311231
based on which instance the log entry is from.
@@ -1277,7 +1277,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10):
12771277
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
12781278
# the job was marked complete.
12791279
last_describe_job_call = time.time()
1280-
last_description = description
12811280
while True:
12821281
if len(stream_names) < instance_count:
12831282
# Log streams are created whenever a container starts writing to stdout/err, so this list

tests/unit/test_session.py

+90
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,38 @@ def test_s3_input_all_arguments():
308308
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
309309
IN_PROGRESS_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'InProgress'})
310310

311+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = {
312+
'TransformJobStatus': 'Completed',
313+
'ModelName': 'some-model',
314+
'TransformJobName': JOB_NAME,
315+
'TransformResources': {
316+
'InstanceCount': INSTANCE_COUNT,
317+
'InstanceType': INSTANCE_TYPE
318+
},
319+
'TransformEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000),
320+
'TransformStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000),
321+
'TransformOutput': {
322+
'AssembleWith': 'None',
323+
'KmsKeyId': '',
324+
'S3OutputPath': S3_OUTPUT
325+
},
326+
'TransformInput': {
327+
'CompressionType': 'None',
328+
'ContentType': 'text/csv',
329+
'DataSource': {
330+
'S3DataType': 'S3Prefix',
331+
'S3Uri': S3_INPUT_URI
332+
},
333+
'SplitType': 'Line'
334+
}
335+
}
336+
337+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
338+
STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'Stopped'})
339+
340+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict(COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT)
341+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT.update({'TransformJobStatus': 'InProgress'})
342+
311343

312344
@pytest.fixture()
313345
def sagemaker_session():
@@ -653,6 +685,7 @@ def sagemaker_session_complete():
653685
boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
654686
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
655687
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
688+
ims.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT
656689
return ims
657690

658691

@@ -663,6 +696,7 @@ def sagemaker_session_stopped():
663696
boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
664697
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
665698
ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
699+
ims.sagemaker_client.describe_transform_job.return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT
666700
return ims
667701

668702

@@ -675,6 +709,9 @@ def sagemaker_session_ready_lifecycle():
675709
ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
676710
IN_PROGRESS_DESCRIBE_JOB_RESULT,
677711
COMPLETED_DESCRIBE_JOB_RESULT]
712+
ims.sagemaker_client.describe_transform_job.side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
713+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
714+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT]
678715
return ims
679716

680717

@@ -687,6 +724,9 @@ def sagemaker_session_full_lifecycle():
687724
ims.sagemaker_client.describe_training_job.side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT,
688725
IN_PROGRESS_DESCRIBE_JOB_RESULT,
689726
COMPLETED_DESCRIBE_JOB_RESULT]
727+
ims.sagemaker_client.describe_transform_job.side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
728+
IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT,
729+
COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT]
690730
return ims
691731

692732

@@ -740,6 +780,56 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
740780
call(0, 'hi there #2a'), call(0, 'hi there #3')]
741781

742782

783+
@patch('sagemaker.logs.ColorWrap')
784+
def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete):
785+
ims = sagemaker_session_complete
786+
ims.logs_for_transform_job(JOB_NAME)
787+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
788+
cw().assert_called_with(0, 'hi there #1')
789+
790+
791+
@patch('sagemaker.logs.ColorWrap')
792+
def test_logs_for_transform_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
793+
ims = sagemaker_session_stopped
794+
ims.logs_for_transform_job(JOB_NAME)
795+
ims.sagemaker_client.describe_transform_job.assert_called_once_with(TransformJobName=JOB_NAME)
796+
cw().assert_called_with(0, 'hi there #1')
797+
798+
799+
@patch('sagemaker.logs.ColorWrap')
800+
def test_logs_for_transform_job_wait_on_completed(cw, sagemaker_session_complete):
801+
ims = sagemaker_session_complete
802+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
803+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
804+
cw().assert_called_with(0, 'hi there #1')
805+
806+
807+
@patch('sagemaker.logs.ColorWrap')
808+
def test_logs_for_transform_job_wait_on_stopped(cw, sagemaker_session_stopped):
809+
ims = sagemaker_session_stopped
810+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
811+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
812+
cw().assert_called_with(0, 'hi there #1')
813+
814+
815+
@patch('sagemaker.logs.ColorWrap')
816+
def test_logs_for_transform_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
817+
ims = sagemaker_session_ready_lifecycle
818+
ims.logs_for_transform_job(JOB_NAME)
819+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)]
820+
cw().assert_called_with(0, 'hi there #1')
821+
822+
823+
@patch('sagemaker.logs.ColorWrap')
824+
@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180])
825+
def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle):
826+
ims = sagemaker_session_full_lifecycle
827+
ims.logs_for_transform_job(JOB_NAME, wait=True, poll=0)
828+
assert ims.sagemaker_client.describe_transform_job.call_args_list == [call(TransformJobName=JOB_NAME,)] * 3
829+
assert cw().call_args_list == [call(0, 'hi there #1'), call(0, 'hi there #2'),
830+
call(0, 'hi there #2a'), call(0, 'hi there #3')]
831+
832+
743833
MODEL_NAME = 'some-model'
744834
PRIMARY_CONTAINER = {
745835
'Environment': {},

0 commit comments

Comments
 (0)