16
16
import pytest
17
17
import sagemaker
18
18
import os
19
+ import warnings
19
20
20
21
from mock import (
21
22
Mock ,
63
64
)
64
65
from tests .unit import DATA_DIR
65
66
66
- SCRIPT_FILE = "dummy_script.py"
67
- SCRIPT_PATH = os .path .join (DATA_DIR , SCRIPT_FILE )
67
+ DUMMY_SCRIPT_PATH = os .path .join (DATA_DIR , "dummy_script.py" )
68
68
69
69
REGION = "us-west-2"
70
70
BUCKET = "my-bucket"
@@ -129,6 +129,31 @@ def sagemaker_session(boto_session, client):
129
129
)
130
130
131
131
132
@pytest.fixture
def script_processor(sagemaker_session):
    """Return a fully configured ScriptProcessor for processing-step tests.

    Exercises every optional field (KMS keys, tags, env vars, network config)
    so the step-serialization tests cover the whole request shape.
    """
    return ScriptProcessor(
        role=ROLE,
        image_uri="012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
        command=["python3"],
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        max_runtime_in_seconds=3600,
        base_job_name="my_sklearn_processor",
        env={"my_env_variable": "my_env_variable_value"},
        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
        network_config=NetworkConfig(
            subnets=["my_subnet_id"],
            security_group_ids=["my_security_group_id"],
            enable_network_isolation=True,
            encrypt_inter_container_traffic=True,
        ),
        sagemaker_session=sagemaker_session,
    )
156
+
132
157
def test_custom_step ():
133
158
step = CustomStep (
134
159
name = "MyStep" , display_name = "CustomStepDisplayName" , description = "CustomStepDescription"
@@ -326,7 +351,7 @@ def test_training_step_tensorflow(sagemaker_session):
326
351
training_epochs_parameter = ParameterInteger (name = "TrainingEpochs" , default_value = 5 )
327
352
training_batch_size_parameter = ParameterInteger (name = "TrainingBatchSize" , default_value = 500 )
328
353
estimator = TensorFlow (
329
- entry_point = os . path . join ( DATA_DIR , SCRIPT_FILE ) ,
354
+ entry_point = DUMMY_SCRIPT_PATH ,
330
355
role = ROLE ,
331
356
model_dir = False ,
332
357
image_uri = IMAGE_URI ,
@@ -403,6 +428,75 @@ def test_training_step_tensorflow(sagemaker_session):
403
428
assert step .properties .TrainingJobName .expr == {"Get" : "Steps.MyTrainingStep.TrainingJobName" }
404
429
405
430
431
def test_training_step_profiler_warning(sagemaker_session):
    """TrainingStep should emit a UserWarning when caching is enabled but the
    estimator still has profiling on (profiler output can defeat the cache)."""
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=False,  # profiling left ON -> warning expected
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 1
        assert issubclass(w[-1].category, UserWarning)
        assert "Profiling is enabled on the provided estimator" in str(w[-1].message)
461
+
462
+
463
def test_training_step_no_profiler_warning(sagemaker_session):
    """No warning should be raised when either profiling is disabled or
    caching is off — the warning only applies to the combination of both."""
    estimator = TensorFlow(
        entry_point=DUMMY_SCRIPT_PATH,
        role=ROLE,
        model_dir=False,
        image_uri=IMAGE_URI,
        source_dir="s3://mybucket/source",
        framework_version="2.4.1",
        py_version="py37",
        disable_profiler=True,  # profiling OFF -> no warning even with caching
        instance_count=1,
        instance_type="ml.p3.16xlarge",
        sagemaker_session=sagemaker_session,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
        debugger_hook_config=False,
        distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    )

    inputs = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    with warnings.catch_warnings(record=True) as w:
        # profiler disabled, cache config not None
        TrainingStep(
            name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
        )
        assert len(w) == 0

    with warnings.catch_warnings(record=True) as w:
        # profiler enabled, cache config is None
        estimator.disable_profiler = False
        TrainingStep(name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=None)
        assert len(w) == 0
498
+
499
+
406
500
def test_processing_step (sagemaker_session ):
407
501
processing_input_data_uri_parameter = ParameterString (
408
502
name = "ProcessingInputDataUri" , default_value = f"s3://{ BUCKET } /processing_manifest"
@@ -473,28 +567,42 @@ def test_processing_step(sagemaker_session):
473
567
474
568
475
569
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, script_processor):
    """When `code` is a local file, ProcessingStep derives a deterministic
    job name (step name + code hash) and passes it to _normalize_args."""
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    inputs = [
        ProcessingInput(
            source=f"s3://{BUCKET}/processing_manifest",
            destination="processing_manifest",
        )
    ]
    outputs = [
        ProcessingOutput(
            source=f"s3://{BUCKET}/processing_manifest",
            destination="processing_manifest",
        )
    ]
    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        code=DUMMY_SCRIPT_PATH,
        inputs=inputs,
        outputs=outputs,
        job_arguments=["arg1", "arg2"],
        cache_config=cache_config,
    )
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()
    # job_name embeds the MD5 of the local code file, so it is stable across runs.
    mock_normalize_args.assert_called_with(
        job_name="MyProcessingStep-3e89f0c7e101c356cbedf27d9d27e9db",
        arguments=step.job_arguments,
        inputs=step.inputs,
        outputs=step.outputs,
        code=step.code,
    )
602
+
603
+
604
+ @patch ("sagemaker.processing.ScriptProcessor._normalize_args" )
605
+ def test_processing_step_normalizes_args_with_s3_code (mock_normalize_args , script_processor ):
498
606
cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" )
499
607
inputs = [
500
608
ProcessingInput (
@@ -510,8 +618,8 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
510
618
]
511
619
step = ProcessingStep (
512
620
name = "MyProcessingStep" ,
513
- processor = processor ,
514
- code = "foo.py " ,
621
+ processor = script_processor ,
622
+ code = "s3:// foo" ,
515
623
inputs = inputs ,
516
624
outputs = outputs ,
517
625
job_arguments = ["arg1" , "arg2" ],
@@ -520,13 +628,48 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
520
628
mock_normalize_args .return_value = [step .inputs , step .outputs ]
521
629
step .to_request ()
522
630
mock_normalize_args .assert_called_with (
631
+ job_name = None ,
523
632
arguments = step .job_arguments ,
524
633
inputs = step .inputs ,
525
634
outputs = step .outputs ,
526
635
code = step .code ,
527
636
)
528
637
529
638
639
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, script_processor):
    """When no `code` is given, ProcessingStep passes job_name=None and
    code=None to _normalize_args (no hash-based job name can be derived)."""
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    inputs = [
        ProcessingInput(
            source=f"s3://{BUCKET}/processing_manifest",
            destination="processing_manifest",
        )
    ]
    outputs = [
        ProcessingOutput(
            source=f"s3://{BUCKET}/processing_manifest",
            destination="processing_manifest",
        )
    ]
    step = ProcessingStep(
        name="MyProcessingStep",
        processor=script_processor,
        inputs=inputs,
        outputs=outputs,
        job_arguments=["arg1", "arg2"],
        cache_config=cache_config,
    )
    mock_normalize_args.return_value = [step.inputs, step.outputs]
    step.to_request()
    mock_normalize_args.assert_called_with(
        job_name=None,
        arguments=step.job_arguments,
        inputs=step.inputs,
        outputs=step.outputs,
        code=None,
    )
+
672
+
530
673
def test_create_model_step (sagemaker_session ):
531
674
model = Model (
532
675
image_uri = IMAGE_URI ,
0 commit comments