@@ -308,6 +308,38 @@ def test_s3_input_all_arguments():
308
308
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict (DEFAULT_EXPECTED_TRAIN_JOB_ARGS )
309
309
IN_PROGRESS_DESCRIBE_JOB_RESULT .update ({'TrainingJobStatus' : 'InProgress' })
310
310
311
+ COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT = {
312
+ 'TransformJobStatus' : 'Completed' ,
313
+ 'ModelName' : 'some-model' ,
314
+ 'TransformJobName' : JOB_NAME ,
315
+ 'TransformResources' : {
316
+ 'InstanceCount' : INSTANCE_COUNT ,
317
+ 'InstanceType' : INSTANCE_TYPE
318
+ },
319
+ 'TransformEndTime' : datetime .datetime (2018 , 2 , 17 , 7 , 19 , 34 , 953000 ),
320
+ 'TransformStartTime' : datetime .datetime (2018 , 2 , 17 , 7 , 15 , 0 , 103000 ),
321
+ 'TransformOutput' : {
322
+ 'AssembleWith' : 'None' ,
323
+ 'KmsKeyId' : '' ,
324
+ 'S3OutputPath' : S3_OUTPUT
325
+ },
326
+ 'TransformInput' : {
327
+ 'CompressionType' : 'None' ,
328
+ 'ContentType' : 'text/csv' ,
329
+ 'DataSource' : {
330
+ 'S3DataType' : 'S3Prefix' ,
331
+ 'S3Uri' : S3_INPUT_URI
332
+ },
333
+ 'SplitType' : 'Line'
334
+ }
335
+ }
336
+
337
+ STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT = dict (COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT )
338
+ STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT .update ({'TransformJobStatus' : 'Stopped' })
339
+
340
+ IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT = dict (COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT )
341
+ IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT .update ({'TransformJobStatus' : 'InProgress' })
342
+
311
343
312
344
@pytest .fixture ()
313
345
def sagemaker_session ():
@@ -653,6 +685,7 @@ def sagemaker_session_complete():
653
685
boto_mock .client ('logs' ).get_log_events .side_effect = DEFAULT_LOG_EVENTS
654
686
ims = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
655
687
ims .sagemaker_client .describe_training_job .return_value = COMPLETED_DESCRIBE_JOB_RESULT
688
+ ims .sagemaker_client .describe_transform_job .return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT
656
689
return ims
657
690
658
691
@@ -663,6 +696,7 @@ def sagemaker_session_stopped():
663
696
boto_mock .client ('logs' ).get_log_events .side_effect = DEFAULT_LOG_EVENTS
664
697
ims = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
665
698
ims .sagemaker_client .describe_training_job .return_value = STOPPED_DESCRIBE_JOB_RESULT
699
+ ims .sagemaker_client .describe_transform_job .return_value = STOPPED_DESCRIBE_TRANSFORM_JOB_RESULT
666
700
return ims
667
701
668
702
@@ -675,6 +709,9 @@ def sagemaker_session_ready_lifecycle():
675
709
ims .sagemaker_client .describe_training_job .side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT ,
676
710
IN_PROGRESS_DESCRIBE_JOB_RESULT ,
677
711
COMPLETED_DESCRIBE_JOB_RESULT ]
712
+ ims .sagemaker_client .describe_transform_job .side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT ,
713
+ IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT ,
714
+ COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT ]
678
715
return ims
679
716
680
717
@@ -687,6 +724,9 @@ def sagemaker_session_full_lifecycle():
687
724
ims .sagemaker_client .describe_training_job .side_effect = [IN_PROGRESS_DESCRIBE_JOB_RESULT ,
688
725
IN_PROGRESS_DESCRIBE_JOB_RESULT ,
689
726
COMPLETED_DESCRIBE_JOB_RESULT ]
727
+ ims .sagemaker_client .describe_transform_job .side_effect = [IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT ,
728
+ IN_PROGRESS_DESCRIBE_TRANSFORM_JOB_RESULT ,
729
+ COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT ]
690
730
return ims
691
731
692
732
@@ -740,6 +780,56 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
740
780
call (0 , 'hi there #2a' ), call (0 , 'hi there #3' )]
741
781
742
782
783
+ @patch ('sagemaker.logs.ColorWrap' )
784
+ def test_logs_for_transform_job_no_wait (cw , sagemaker_session_complete ):
785
+ ims = sagemaker_session_complete
786
+ ims .logs_for_transform_job (JOB_NAME )
787
+ ims .sagemaker_client .describe_transform_job .assert_called_once_with (TransformJobName = JOB_NAME )
788
+ cw ().assert_called_with (0 , 'hi there #1' )
789
+
790
+
791
+ @patch ('sagemaker.logs.ColorWrap' )
792
+ def test_logs_for_transform_job_no_wait_stopped_job (cw , sagemaker_session_stopped ):
793
+ ims = sagemaker_session_stopped
794
+ ims .logs_for_transform_job (JOB_NAME )
795
+ ims .sagemaker_client .describe_transform_job .assert_called_once_with (TransformJobName = JOB_NAME )
796
+ cw ().assert_called_with (0 , 'hi there #1' )
797
+
798
+
799
+ @patch ('sagemaker.logs.ColorWrap' )
800
+ def test_logs_for_transform_job_wait_on_completed (cw , sagemaker_session_complete ):
801
+ ims = sagemaker_session_complete
802
+ ims .logs_for_transform_job (JOB_NAME , wait = True , poll = 0 )
803
+ assert ims .sagemaker_client .describe_transform_job .call_args_list == [call (TransformJobName = JOB_NAME ,)]
804
+ cw ().assert_called_with (0 , 'hi there #1' )
805
+
806
+
807
+ @patch ('sagemaker.logs.ColorWrap' )
808
+ def test_logs_for_transform_job_wait_on_stopped (cw , sagemaker_session_stopped ):
809
+ ims = sagemaker_session_stopped
810
+ ims .logs_for_transform_job (JOB_NAME , wait = True , poll = 0 )
811
+ assert ims .sagemaker_client .describe_transform_job .call_args_list == [call (TransformJobName = JOB_NAME ,)]
812
+ cw ().assert_called_with (0 , 'hi there #1' )
813
+
814
+
815
+ @patch ('sagemaker.logs.ColorWrap' )
816
+ def test_logs_for_transform_job_no_wait_on_running (cw , sagemaker_session_ready_lifecycle ):
817
+ ims = sagemaker_session_ready_lifecycle
818
+ ims .logs_for_transform_job (JOB_NAME )
819
+ assert ims .sagemaker_client .describe_transform_job .call_args_list == [call (TransformJobName = JOB_NAME ,)]
820
+ cw ().assert_called_with (0 , 'hi there #1' )
821
+
822
+
823
+ @patch ('sagemaker.logs.ColorWrap' )
824
+ @patch ('time.time' , side_effect = [0 , 30 , 60 , 90 , 120 , 150 , 180 ])
825
+ def test_logs_for_transform_job_full_lifecycle (time , cw , sagemaker_session_full_lifecycle ):
826
+ ims = sagemaker_session_full_lifecycle
827
+ ims .logs_for_transform_job (JOB_NAME , wait = True , poll = 0 )
828
+ assert ims .sagemaker_client .describe_transform_job .call_args_list == [call (TransformJobName = JOB_NAME ,)] * 3
829
+ assert cw ().call_args_list == [call (0 , 'hi there #1' ), call (0 , 'hi there #2' ),
830
+ call (0 , 'hi there #2a' ), call (0 , 'hi there #3' )]
831
+
832
+
743
833
MODEL_NAME = 'some-model'
744
834
PRIMARY_CONTAINER = {
745
835
'Environment' : {},
0 commit comments