26
26
_REPACK_MODEL_NAME_BASE ,
27
27
)
28
28
from sagemaker .workflow .parameters import ParameterString
29
- from sagemaker .workflow .pipeline import Pipeline
29
+ from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
30
30
from sagemaker .workflow .pipeline_context import PipelineSession
31
31
from sagemaker .workflow .utilities import list_to_request
32
32
from tests .unit import DATA_DIR
@@ -268,7 +268,7 @@ def test_step_collection_properties(pipeline_session, sagemaker_session):
268
268
steps = register_model .steps
269
269
assert len (steps ) == 1
270
270
assert register_model .properties .ModelPackageName .expr == {
271
- "Get" : f"Steps.{ register_model_step_name } .ModelPackageName"
271
+ "Get" : f"Steps.{ register_model_step_name } -RegisterModel .ModelPackageName"
272
272
}
273
273
274
274
# Custom StepCollection
@@ -330,10 +330,9 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
330
330
step_list = json .loads (pipeline .definition ())["Steps" ]
331
331
assert len (step_list ) == 7
332
332
for step in step_list :
333
- if step ["Name" ] not in ["MyStep2" , "MyStep3" , f"{ model_name } RepackModel" ]:
333
+ if step ["Name" ] not in ["MyStep2" , "MyStep3" , f"{ model_name } - RepackModel" ]:
334
334
assert "DependsOn" not in step
335
- continue
336
- if step ["Name" ] == f"{ model_name } RepackModel" :
335
+ elif step ["Name" ] == f"{ model_name } -RepackModel" :
337
336
assert set (step ["DependsOn" ]) == {
338
337
"MyStep1" ,
339
338
f"{ model_step_name } -{ _REPACK_MODEL_NAME_BASE } -{ model_name } " ,
@@ -344,9 +343,21 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
344
343
"MyStep1" ,
345
344
f"{ model_step_name } -{ _REPACK_MODEL_NAME_BASE } -{ model_name } " ,
346
345
f"{ model_step_name } -{ _CREATE_MODEL_NAME_BASE } " ,
347
- f"{ model_name } RepackModel" ,
348
- register_model_name ,
346
+ f"{ model_name } - RepackModel" ,
347
+ f" { register_model_name } -RegisterModel" ,
349
348
}
349
+ adjacency_list = PipelineGraph .from_pipeline (pipeline ).adjacency_list
350
+ assert ordered (adjacency_list ) == ordered (
351
+ {
352
+ "MyStep1" : ["MyStep2" , "MyStep3" , "MyModel-RepackModel" ],
353
+ "MyStep2" : [],
354
+ "MyStep3" : [],
355
+ "MyModelStep-RepackModel-MyModel" : ["MyModelStep-CreateModel" ],
356
+ "MyModelStep-CreateModel" : ["MyStep2" , "MyStep3" , "MyModel-RepackModel" ],
357
+ "MyModel-RepackModel" : [],
358
+ "RegisterModelStep-RegisterModel" : ["MyStep2" , "MyStep3" ],
359
+ }
360
+ )
350
361
351
362
352
363
def test_register_model (estimator , model_metrics , drift_check_baselines ):
@@ -378,7 +389,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
378
389
assert ordered (register_model .request_dicts ()) == ordered (
379
390
[
380
391
{
381
- "Name" : "RegisterModelStep" ,
392
+ "Name" : "RegisterModelStep-RegisterModel " ,
382
393
"Type" : "RegisterModel" ,
383
394
"DependsOn" : ["TestStep" ],
384
395
"DisplayName" : "RegisterModelStep" ,
@@ -450,7 +461,7 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
450
461
assert ordered (register_model .request_dicts ()) == ordered (
451
462
[
452
463
{
453
- "Name" : "RegisterModelStep" ,
464
+ "Name" : "RegisterModelStep-RegisterModel " ,
454
465
"Type" : "RegisterModel" ,
455
466
"Description" : "description" ,
456
467
"Arguments" : {
@@ -526,7 +537,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines):
526
537
assert ordered (register_model .request_dicts ()) == ordered (
527
538
[
528
539
{
529
- "Name" : "RegisterModelStep" ,
540
+ "Name" : "RegisterModelStep-RegisterModel " ,
530
541
"Type" : "RegisterModel" ,
531
542
"Description" : "description" ,
532
543
"DependsOn" : ["TestStep" ],
@@ -618,7 +629,7 @@ def test_register_model_with_model_repack_with_estimator(
618
629
619
630
for request_dict in request_dicts :
620
631
if request_dict ["Type" ] == "Training" :
621
- assert request_dict ["Name" ] == "RegisterModelStepRepackModel "
632
+ assert request_dict ["Name" ] == "RegisterModelStep-RepackModel "
622
633
assert len (request_dict ["DependsOn" ]) == 1
623
634
assert request_dict ["DependsOn" ][0 ] == "TestStep"
624
635
arguments = request_dict ["Arguments" ]
@@ -671,7 +682,7 @@ def test_register_model_with_model_repack_with_estimator(
671
682
}
672
683
)
673
684
elif request_dict ["Type" ] == "RegisterModel" :
674
- assert request_dict ["Name" ] == "RegisterModelStep"
685
+ assert request_dict ["Name" ] == "RegisterModelStep-RegisterModel "
675
686
assert "DependsOn" not in request_dict
676
687
arguments = request_dict ["Arguments" ]
677
688
assert len (arguments ["InferenceSpecification" ]["Containers" ]) == 1
@@ -745,7 +756,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
745
756
746
757
for request_dict in request_dicts :
747
758
if request_dict ["Type" ] == "Training" :
748
- assert request_dict ["Name" ] == "modelNameRepackModel "
759
+ assert request_dict ["Name" ] == "modelName-RepackModel "
749
760
assert len (request_dict ["DependsOn" ]) == 1
750
761
assert request_dict ["DependsOn" ][0 ] == "TestStep"
751
762
arguments = request_dict ["Arguments" ]
@@ -798,7 +809,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift
798
809
}
799
810
)
800
811
elif request_dict ["Type" ] == "RegisterModel" :
801
- assert request_dict ["Name" ] == "RegisterModelStep"
812
+ assert request_dict ["Name" ] == "RegisterModelStep-RegisterModel "
802
813
assert "DependsOn" not in request_dict
803
814
arguments = request_dict ["Arguments" ]
804
815
assert len (arguments ["InferenceSpecification" ]["Containers" ]) == 1
@@ -874,7 +885,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
874
885
875
886
for request_dict in request_dicts :
876
887
if request_dict ["Type" ] == "Training" :
877
- assert request_dict ["Name" ] == "modelNameRepackModel "
888
+ assert request_dict ["Name" ] == "modelName-RepackModel "
878
889
assert len (request_dict ["DependsOn" ]) == 1
879
890
assert request_dict ["DependsOn" ][0 ] == "TestStep"
880
891
arguments = request_dict ["Arguments" ]
@@ -927,7 +938,7 @@ def test_register_model_with_model_repack_with_pipeline_model(
927
938
}
928
939
)
929
940
elif request_dict ["Type" ] == "RegisterModel" :
930
- assert request_dict ["Name" ] == "RegisterModelStep"
941
+ assert request_dict ["Name" ] == "RegisterModelStep-RegisterModel "
931
942
assert "DependsOn" not in request_dict
932
943
arguments = request_dict ["Arguments" ]
933
944
assert len (arguments ["InferenceSpecification" ]["Containers" ]) == 1
0 commit comments