69
69
from tests .unit import DATA_DIR
70
70
from tests .unit .sagemaker .workflow .conftest import ROLE , BUCKET , IMAGE_URI , INSTANCE_TYPE
71
71
72
+ HF_INSTANCE_TYPE = "ml.p3.2xlarge"
72
73
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
73
74
LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "workflow/abalone/preprocessing.py" )
74
75
SPARK_APP_JAR_PATH = os .path .join (
142
143
pytorch_version = "1.7" ,
143
144
role = ROLE ,
144
145
instance_count = 1 ,
145
- instance_type = "ml.p3.2xlarge" ,
146
+ instance_type = HF_INSTANCE_TYPE ,
146
147
),
147
148
{"code" : DUMMY_S3_SCRIPT_PATH },
148
149
),
@@ -446,8 +447,15 @@ def test_processing_step_with_framework_processor(
446
447
):
447
448
448
449
processor , run_inputs = framework_processor
450
+ default_instance_type = (
451
+ HF_INSTANCE_TYPE if type (processor ) is HuggingFaceProcessor else INSTANCE_TYPE
452
+ )
453
+ instance_type_param = ParameterString (
454
+ name = "ProcessingInstanceType" , default_value = default_instance_type
455
+ )
449
456
processor .sagemaker_session = pipeline_session
450
457
processor .role = ROLE
458
+ processor .instance_type = instance_type_param
451
459
452
460
processor .volume_kms_key = "volume-kms-key"
453
461
processor .network_config = network_config
@@ -465,6 +473,7 @@ def test_processing_step_with_framework_processor(
465
473
name = "MyPipeline" ,
466
474
steps = [step ],
467
475
sagemaker_session = pipeline_session ,
476
+ parameters = [instance_type_param ],
468
477
)
469
478
470
479
step_args = get_step_args_helper (step_args , "Processing" )
@@ -475,6 +484,12 @@ def test_processing_step_with_framework_processor(
475
484
step_args ["ProcessingOutputConfig" ]["Outputs" ][0 ]["S3Output" ]["S3Uri" ]
476
485
== processing_output .destination
477
486
)
487
+ assert (
488
+ type (step_args ["ProcessingResources" ]["ClusterConfig" ]["InstanceType" ]) is ParameterString
489
+ )
490
+ step_args ["ProcessingResources" ]["ClusterConfig" ]["InstanceType" ] = step_args [
491
+ "ProcessingResources"
492
+ ]["ClusterConfig" ]["InstanceType" ].expr
478
493
479
494
del step_args ["ProcessingInputs" ][0 ]["S3Input" ]["S3Uri" ]
480
495
del step_def ["Arguments" ]["ProcessingInputs" ][0 ]["S3Input" ]["S3Uri" ]
0 commit comments