@@ -67,11 +67,22 @@
     ConditionLessThanOrEqualTo,
 )
 from sagemaker.workflow.condition_step import ConditionStep
-from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
-from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
+from sagemaker.workflow.callback_step import (
+    CallbackStep,
+    CallbackOutput,
+    CallbackOutputTypeEnum,
+)
+from sagemaker.workflow.lambda_step import (
+    LambdaStep,
+    LambdaOutput,
+    LambdaOutputTypeEnum,
+)
 from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
 from sagemaker.wrangler.processing import DataWranglerProcessor
-from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
+from sagemaker.dataset_definition.inputs import (
+    DatasetDefinition,
+    AthenaDatasetDefinition,
+)
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.functions import Join, JsonGet
 from sagemaker.wrangler.ingestion import generate_data_ingestion_flow_from_s3_input
@@ -92,7 +103,11 @@
 from sagemaker.workflow.step_collections import RegisterModel
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.lambda_helper import Lambda
-from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
+from sagemaker.feature_store.feature_group import (
+    FeatureGroup,
+    FeatureDefinition,
+    FeatureTypeEnum,
+)
 from tests.integ import DATA_DIR
 from tests.integ.kms_utils import get_or_create_kms_key
 from tests.integ.retry import retries
@@ -262,7 +277,10 @@ def build_jar():
         )
     else:
         subprocess.run(
-            ["javac", os.path.join(jar_file_path, java_file_path, "HelloJavaSparkApp.java")]
+            [
+                "javac",
+                os.path.join(jar_file_path, java_file_path, "HelloJavaSparkApp.java"),
+            ]
         )
 
     subprocess.run(
@@ -383,10 +401,20 @@ def test_three_step_definition(
     assert set(tuple(param.items()) for param in definition["Parameters"]) == set(
         [
             tuple(
-                {"Name": "InstanceType", "Type": "String", "DefaultValue": "ml.m5.xlarge"}.items()
+                {
+                    "Name": "InstanceType",
+                    "Type": "String",
+                    "DefaultValue": "ml.m5.xlarge",
+                }.items()
             ),
             tuple({"Name": "InstanceCount", "Type": "Integer", "DefaultValue": 1}.items()),
-            tuple({"Name": "OutputPrefix", "Type": "String", "DefaultValue": "output"}.items()),
+            tuple(
+                {
+                    "Name": "OutputPrefix",
+                    "Type": "String",
+                    "DefaultValue": "output",
+                }.items()
+            ),
         ]
     )
 
@@ -740,7 +768,13 @@ def test_one_step_pyspark_processing_pipeline(
 
 
 def test_one_step_sparkjar_processing_pipeline(
-    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name, configuration, build_jar
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+    configuration,
+    build_jar,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     cache_config = CacheConfig(enable_caching=True, expire_after="T30m")
@@ -758,7 +792,9 @@ def test_one_step_sparkjar_processing_pipeline(
         body = data.read()
     input_data_uri = f"s3://{bucket}/spark/input/data.jsonl"
     S3Uploader.upload_string_as_file_body(
-        body=body, desired_s3_uri=input_data_uri, sagemaker_session=sagemaker_session
+        body=body,
+        desired_s3_uri=input_data_uri,
+        sagemaker_session=sagemaker_session,
     )
     output_data_uri = f"s3://{bucket}/spark/output/sales/{datetime.now().isoformat()}"
 
@@ -877,7 +913,12 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
 
 
 def test_steps_with_map_params_pipeline(
-    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
+    sagemaker_session,
+    role,
+    script_dir,
+    pipeline_name,
+    region_name,
+    athena_dataset_definition,
 ):
     instance_count = ParameterInteger(name="InstanceCount", default_value=2)
     framework_version = "0.20.0"
@@ -1184,7 +1225,8 @@ def test_two_steps_emr_pipeline(sagemaker_session, role, pipeline_name, region_n
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
     finally:
         try:
@@ -1267,7 +1309,12 @@ def test_conditional_pytorch_training_model_registration(
 
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[in_condition_input, good_enough_input, instance_count, instance_type],
+        parameters=[
+            in_condition_input,
+            good_enough_input,
+            instance_count,
+            instance_type,
+        ],
         steps=[step_cond],
         sagemaker_session=sagemaker_session,
     )
@@ -1276,7 +1323,8 @@ def test_conditional_pytorch_training_model_registration(
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
 
         execution = pipeline.start(parameters={})
@@ -1395,7 +1443,8 @@ def test_tuning_single_algo(
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
        )
 
         execution = pipeline.start(parameters={})
@@ -1522,7 +1571,8 @@ def test_tuning_multi_algos(
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
 
         execution = pipeline.start(parameters={})
@@ -1583,7 +1633,8 @@ def test_mxnet_model_registration(
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
 
         execution = pipeline.start(parameters={})
@@ -1655,10 +1706,14 @@ def test_sklearn_xgboost_sip_model_registration(
             destination=train_data_path_param,
         ),
         ProcessingOutput(
-            output_name="val_data", source="/opt/ml/processing/val", destination=val_data_path_param
+            output_name="val_data",
+            source="/opt/ml/processing/val",
+            destination=val_data_path_param,
         ),
         ProcessingOutput(
-            output_name="model", source="/opt/ml/processing/model", destination=model_path_param
+            output_name="model",
+            source="/opt/ml/processing/model",
+            destination=model_path_param,
         ),
     ]
 
@@ -1775,7 +1830,8 @@ def test_sklearn_xgboost_sip_model_registration(
         response = pipeline.upsert(role_arn=role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
 
         execution = pipeline.start(parameters={})
@@ -1831,7 +1887,9 @@ def test_model_registration_with_drift_check_baselines(
         utils.unique_name_from_base("metrics"),
     )
     metrics_uri = S3Uploader.upload_string_as_file_body(
-        body=metrics_data, desired_s3_uri=metrics_base_uri, sagemaker_session=sagemaker_session
+        body=metrics_data,
+        desired_s3_uri=metrics_base_uri,
+        sagemaker_session=sagemaker_session,
     )
     metrics_uri_param = ParameterString(name="metrics_uri", default_value=metrics_uri)
 
@@ -2070,7 +2128,8 @@ def test_model_registration_with_model_repack(
         response = pipeline.create(role)
         create_arn = response["PipelineArn"]
         assert re.match(
-            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
         )
 
         execution = pipeline.start(parameters={})
@@ -2417,13 +2476,17 @@ def test_one_step_ingestion_pipeline(
     input_name = "features.csv"
     input_file_path = os.path.join(DATA_DIR, "workflow", "features.csv")
     input_data_uri = os.path.join(
-        "s3://", sagemaker_session.default_bucket(), "py-sdk-ingestion-test-input/features.csv"
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "py-sdk-ingestion-test-input/features.csv",
    )
 
     with open(input_file_path, "r") as data:
         body = data.read()
         S3Uploader.upload_string_as_file_body(
-            body=body, desired_s3_uri=input_data_uri, sagemaker_session=sagemaker_session
+            body=body,
+            desired_s3_uri=input_data_uri,
+            sagemaker_session=sagemaker_session,
         )
 
     inputs = [
@@ -2735,7 +2798,9 @@ def test_end_to_end_pipeline_successful_execution(
         sagemaker_session=sagemaker_session,
     )
     step_transform = TransformStep(
-        name="AbaloneTransform", transformer=transformer, inputs=TransformInput(data=batch_data)
+        name="AbaloneTransform",
+        transformer=transformer,
+        inputs=TransformInput(data=batch_data),
     )
 
     # define register model step