66
66
ConditionIn ,
67
67
ConditionLessThanOrEqualTo ,
68
68
)
69
- from sagemaker .workflow .condition_step import ConditionStep , JsonGet
69
+ from sagemaker .workflow .condition_step import ConditionStep
70
70
from sagemaker .workflow .callback_step import CallbackStep , CallbackOutput , CallbackOutputTypeEnum
71
71
from sagemaker .workflow .lambda_step import LambdaStep , LambdaOutput , LambdaOutputTypeEnum
72
- from sagemaker .workflow .properties import PropertyFile
73
72
from sagemaker .wrangler .processing import DataWranglerProcessor
74
73
from sagemaker .dataset_definition .inputs import DatasetDefinition , AthenaDatasetDefinition
75
74
from sagemaker .workflow .execution_variables import ExecutionVariables
76
- from sagemaker .workflow .functions import Join
75
+ from sagemaker .workflow .functions import Join , JsonGet
77
76
from sagemaker .wrangler .ingestion import generate_data_ingestion_flow_from_s3_input
78
77
from sagemaker .workflow .parameters import (
79
78
ParameterInteger ,
87
86
TuningStep ,
88
87
TransformStep ,
89
88
TransformInput ,
89
+ PropertyFile ,
90
90
)
91
91
from sagemaker .workflow .step_collections import RegisterModel
92
92
from sagemaker .workflow .pipeline import Pipeline
@@ -137,7 +137,7 @@ def feature_store_session(sagemaker_session):
137
137
138
138
@pytest .fixture
139
139
def pipeline_name ():
140
- return f"my-pipeline-{ int (time .time () * 10 ** 7 )} "
140
+ return f"my-pipeline-{ int (time .time () * 10 ** 7 )} "
141
141
142
142
143
143
@pytest .fixture
@@ -1371,6 +1371,8 @@ def test_tuning_multi_algos(
1371
1371
cpu_instance_type ,
1372
1372
pipeline_name ,
1373
1373
region_name ,
1374
+ script_dir ,
1375
+ athena_dataset_definition ,
1374
1376
):
1375
1377
base_dir = os .path .join (DATA_DIR , "pytorch_mnist" )
1376
1378
entry_point = os .path .join (base_dir , "mnist.py" )
@@ -1382,6 +1384,42 @@ def test_tuning_multi_algos(
1382
1384
instance_count = ParameterInteger (name = "InstanceCount" , default_value = 1 )
1383
1385
instance_type = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
1384
1386
1387
+ input_data = f"s3://sagemaker-sample-data-{ region_name } /processing/census/census-income.csv"
1388
+
1389
+ sklearn_processor = SKLearnProcessor (
1390
+ framework_version = "0.20.0" ,
1391
+ instance_type = instance_type ,
1392
+ instance_count = instance_count ,
1393
+ base_job_name = "test-sklearn" ,
1394
+ sagemaker_session = sagemaker_session ,
1395
+ role = role ,
1396
+ )
1397
+
1398
+ property_file = PropertyFile (
1399
+ name = "DataAttributes" , output_name = "attributes" , path = "attributes.json"
1400
+ )
1401
+
1402
+ step_process = ProcessingStep (
1403
+ name = "my-process" ,
1404
+ display_name = "ProcessingStep" ,
1405
+ description = "description for Processing step" ,
1406
+ processor = sklearn_processor ,
1407
+ inputs = [
1408
+ ProcessingInput (source = input_data , destination = "/opt/ml/processing/input" ),
1409
+ ProcessingInput (dataset_definition = athena_dataset_definition ),
1410
+ ],
1411
+ outputs = [
1412
+ ProcessingOutput (output_name = "train_data" , source = "/opt/ml/processing/train" ),
1413
+ ProcessingOutput (output_name = "attributes" , source = "/opt/ml/processing/attributes.json" ),
1414
+ ],
1415
+ property_files = [property_file ],
1416
+ code = os .path .join (script_dir , "preprocessing.py" ),
1417
+ )
1418
+
1419
+ static_hp_1 = ParameterString (name = "InstanceType" , default_value = "ml.m5.xlarge" )
1420
+ json_get_hp = JsonGet (
1421
+ step_name = step_process .name , property_file = property_file , json_path = "train_size"
1422
+ )
1385
1423
pytorch_estimator = PyTorch (
1386
1424
entry_point = entry_point ,
1387
1425
role = role ,
@@ -1392,10 +1430,11 @@ def test_tuning_multi_algos(
1392
1430
sagemaker_session = sagemaker_session ,
1393
1431
enable_sagemaker_metrics = True ,
1394
1432
max_retry_attempts = 3 ,
1433
+ hyperparameters = {"static-hp" : static_hp_1 , "train_size" : json_get_hp },
1395
1434
)
1396
1435
1397
1436
min_batch_size = ParameterString (name = "MinBatchSize" , default_value = "64" )
1398
- max_batch_size = ParameterString ( name = "MaxBatchSize" , default_value = "128" )
1437
+ max_batch_size = json_get_hp
1399
1438
1400
1439
tuner = HyperparameterTuner .create (
1401
1440
estimator_dict = {
@@ -1415,6 +1454,7 @@ def test_tuning_multi_algos(
1415
1454
"estimator-2" : [{"Name" : "test:acc" , "Regex" : "Overall test accuracy: (.*?);" }],
1416
1455
},
1417
1456
)
1457
+
1418
1458
inputs = {
1419
1459
"estimator-1" : TrainingInput (s3_data = input_path ),
1420
1460
"estimator-2" : TrainingInput (s3_data = input_path ),
@@ -1429,7 +1469,7 @@ def test_tuning_multi_algos(
1429
1469
pipeline = Pipeline (
1430
1470
name = pipeline_name ,
1431
1471
parameters = [instance_count , instance_type , min_batch_size , max_batch_size ],
1432
- steps = [step_tune ],
1472
+ steps = [step_process , step_tune ],
1433
1473
sagemaker_session = sagemaker_session ,
1434
1474
)
1435
1475
0 commit comments