16
16
import pytest
17
17
import sagemaker
18
18
import os
19
+ import warnings
19
20
20
21
from mock import (
21
22
Mock ,
63
64
)
64
65
from tests .unit import DATA_DIR
65
66
66
# Path to the shared dummy entry-point script exercised by the training and
# processing step tests below.
DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")


# Region and bucket every mocked SageMaker session in this module points at.
REGION = "us-west-2"
BUCKET = "my-bucket"
@@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client):
129
129
)
130
130
131
131
132
@pytest.fixture
def script_processor(sagemaker_session):
    """Return a fully configured ScriptProcessor shared by the processing-step tests."""
    # Keep the network settings in a named local so the fixture body reads as
    # "processor + its network config" rather than one large argument list.
    network_config = NetworkConfig(
        subnets=["my_subnet_id"],
        security_group_ids=["my_security_group_id"],
        enable_network_isolation=True,
        encrypt_inter_container_traffic=True,
    )
    return ScriptProcessor(
        role=ROLE,
        image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
        command=["python3"],
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        max_runtime_in_seconds=3600,
        base_job_name="my_sklearn_processor",
        env={"my_env_variable": "my_env_variable_value"},
        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
        network_config=network_config,
        sagemaker_session=sagemaker_session,
    )
155
+
156
+
132
157
def test_custom_step ():
133
158
step = CustomStep (
134
159
name = "MyStep" , display_name = "CustomStepDisplayName" , description = "CustomStepDescription"
@@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session):
326
351
training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
327
352
training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
328
353
estimator = TensorFlow (
329
- entry_point = os . path . join ( DATA_DIR , SCRIPT_FILE ) ,
354
+ entry_point = DUMMY_SCRIPT_PATH ,
330
355
role = ROLE ,
331
356
model_dir = False ,
332
357
image_uri = IMAGE_URI ,
@@ -403,6 +428,101 @@ def test_training_step_tensorflow(sagemaker_session):
403
428
assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
404
429
405
430
431
def test_training_step_profiler_warning(sagemaker_session):
    """A TrainingStep built with caching and a profiler-enabled estimator must warn.

    The estimator has ``disable_profiler=False`` and the step has a non-None
    ``cache_config``, so constructing the TrainingStep should emit exactly one
    UserWarning about profiling being enabled.
    """
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=False,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        # "always" defeats the per-module __warningregistry__ dedup, so the
        # warning is captured even if an earlier test already triggered it.
        warnings.simplefilter("always")
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert "Profiling is enabled on the provided estimator" in str(w[-1].message)
461
+
462
+
463
def test_training_step_no_profiler_warning(sagemaker_session):
    """No profiler warning when either the profiler or step caching is off.

    Covers both quiet combinations: profiler disabled with caching on, and
    profiler enabled with ``cache_config=None``.
    """
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=True,
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        # "always" ensures any warning would be recorded (not swallowed by the
        # default filters), so len(w) == 0 genuinely proves nothing was raised.
        warnings.simplefilter("always")
        # profiler disabled, cache config not None
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 0

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # profiler enabled, cache config is None
        estimator.disable_profiler = False
        TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None)
        assert len(w) == 0
498
+
499
+
500
def test_training_step_profiler_not_explicitly_enabled(sagemaker_session):
    """When the estimator does not set ``disable_profiler`` at all, the step
    request carries no profiler rule configurations."""
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={"batch-size": 500, "epochs": 5},
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
    step = TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs)
    request = step.to_request()
    assert request["Arguments"]["ProfilerRuleConfigurations"] is None
524
+
525
+
406
526
def test_processing_step (sagemaker_session ):
407
527
processing_input_data_uri_parameter = ParameterString (
408
528
name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
@@ -473,28 +593,42 @@ def test_processing_step(sagemaker_session):
473
593
474
594
475
595
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor):
    """With a local script path, ``to_request`` passes the step name (plus a
    code hash) as ``job_name`` into ``_normalize_args``."""
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    manifest_uri = f"s3://{BUCKET}/processing_manifest"
    inputs = [ProcessingInput(source=manifest_uri, destination="processing_manifest")]
    outputs = [ProcessingOutput(source=manifest_uri, destination="processing_manifest")]
    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        code=DUMMY_SCRIPT_PATH,
        inputs=inputs,
        outputs=outputs,
        job_arguments=["arg1", "arg2"],
        cache_config=cache_config,
    )
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()
    mock_normalize_args.assert_called_with(
        # NOTE(review): hash suffix presumably derives from the dummy script's
        # contents — brittle if that fixture file ever changes.
        job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db",
        arguments=step.job_arguments,
        inputs=step.inputs,
        outputs=step.outputs,
        code=step.code,
    )
628
+
629
+
630
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, script_processor):
    """With S3-hosted code, ``to_request`` forwards ``job_name=None`` (no local
    file to hash) while still passing the code URI through."""
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    manifest_uri = f"s3://{BUCKET}/processing_manifest"
    inputs = [ProcessingInput(source=manifest_uri, destination="processing_manifest")]
    outputs = [ProcessingOutput(source=manifest_uri, destination="processing_manifest")]
    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        code="s3://foo",
        inputs=inputs,
        outputs=outputs,
        job_arguments=["arg1", "arg2"],
        cache_config=cache_config,
    )
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()
    mock_normalize_args.assert_called_with(
        job_name=None,
        arguments=step.job_arguments,
        inputs=step.inputs,
        outputs=step.outputs,
        code=step.code,
    )
528
663
529
664
665
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor):
    """Without a ``code`` argument, ``to_request`` calls ``_normalize_args``
    with both ``job_name`` and ``code`` set to None."""
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    manifest_uri = f"s3://{BUCKET}/processing_manifest"
    inputs = [ProcessingInput(source=manifest_uri, destination="processing_manifest")]
    outputs = [ProcessingOutput(source=manifest_uri, destination="processing_manifest")]
    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        inputs=inputs,
        outputs=outputs,
        job_arguments=["arg1", "arg2"],
        cache_config=cache_config,
    )
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()
    mock_normalize_args.assert_called_with(
        job_name=None,
        arguments=step.job_arguments,
        inputs=step.inputs,
        outputs=step.outputs,
        code=None,
    )
697
+
698
+
530
699
def test_create_model_step (sagemaker_session ):
531
700
model = Model (
532
701
image_uri = IMAGE_URI ,
0 commit comments