@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539
539
_deploy_static_endpoint (
540
540
execution_arn = pipeline_execution_arn , sagemaker_session = sagemaker_session
541
541
)
542
-
542
+ logging . info ( f"Using static pipeline { pipeline_execution_arn } " )
543
543
return pipeline_execution_arn
544
544
545
545
@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608
608
entities = [LineageEntityEnum .TRIAL_COMPONENT ], sources = [LineageSourceEnum .TRAINING_JOB ]
609
609
)
610
610
611
+ model_artifact_arn = static_model_artifact .artifact_arn
611
612
query_result = LineageQuery (sagemaker_session ).query (
612
- start_arns = [static_model_artifact . artifact_arn ],
613
+ start_arns = [model_artifact_arn ],
613
614
query_filter = query_filter ,
614
615
direction = LineageQueryDirectionEnum .ASCENDANTS ,
615
616
include_edges = False ,
616
617
)
618
+ logging .info (
619
+ f"Found { len (query_result .vertices )} trial components from model artifact { model_artifact_arn } "
620
+ )
617
621
training_jobs = []
618
622
for vertex in query_result .vertices :
619
623
training_jobs .append (vertex .to_lineage_object ())
620
624
625
+ if not training_jobs :
626
+ raise Exception (f"No training job found for static model artifact { model_artifact_arn } " )
627
+
621
628
return training_jobs [0 ]
622
629
623
630
@@ -642,7 +649,7 @@ def static_transform_job_trial_component(
642
649
643
650
@pytest .fixture
644
651
def static_endpoint_context (sagemaker_session , static_pipeline_execution_arn ):
645
- endpoint_arn = get_endpoint_arn_from_static_pipeline ( sagemaker_session )
652
+ logging . info ( f"Using endpoint { endpoint_arn } from static pipeline" )
646
653
647
654
if endpoint_arn is None :
648
655
_deploy_static_endpoint (
@@ -651,6 +658,8 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
651
658
)
652
659
endpoint_arn = get_endpoint_arn_from_static_pipeline (sagemaker_session )
653
660
661
+ endpoint_arn = get_endpoint_arn_from_static_pipeline (sagemaker_session )
662
+
654
663
contexts = sagemaker_session .sagemaker_client .list_contexts (SourceUri = endpoint_arn )[
655
664
"ContextSummaries"
656
665
]
@@ -663,10 +672,10 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
663
672
Expected 1 but got { len (contexts )} "
664
673
)
665
674
)
666
-
667
- yield context . EndpointContext . load (
668
- contexts [ 0 ][ "ContextName" ], sagemaker_session = sagemaker_session
669
- )
675
+ context = context [ 0 ]
676
+ context_arn = context [ "ContextArn" ]
677
+ logging . info ( f"Using context { context_arn } for static endpoint context" )
678
+ yield context . EndpointContext . load ( context [ "ContextName" ], sagemaker_session = sagemaker_session )
670
679
671
680
672
681
@pytest .fixture
@@ -709,27 +718,41 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
709
718
)
710
719
)
711
720
712
- yield artifact . ModelArtifact . load (
713
- artifacts [ 0 ][ "ArtifactArn" ], sagemaker_session = sagemaker_session
714
- )
721
+ artifact_arn = artifacts [ 0 ][ "ArtifactArn" ]
722
+ logging . info ( f"Using static model artifact { artifact_arn } " )
723
+ yield artifact . ModelArtifact . load ( artifact_arn , sagemaker_session = sagemaker_session )
715
724
716
725
717
726
@pytest .fixture
718
727
def static_dataset_artifact (static_model_artifact , sagemaker_session ):
728
+ model_artifact_arn = static_model_artifact .artifact_arn
719
729
dataset_associations = sagemaker_session .sagemaker_client .list_associations (
720
- DestinationArn = static_model_artifact .artifact_arn , SourceType = "DataSet"
730
+ DestinationArn = model_artifact_arn , SourceType = "DataSet"
731
+ )
732
+ logging .info (
733
+ f"Found { len (dataset_associations )} associated with model artifact { model_artifact_arn } "
721
734
)
722
735
if len (dataset_associations ["AssociationSummaries" ]) == 0 :
723
736
# no directly associated dataset. work backwards from the model
724
737
model_associations = sagemaker_session .sagemaker_client .list_associations (
725
- DestinationArn = static_model_artifact .artifact_arn , SourceType = "Model"
726
- )
738
+ DestinationArn = model_artifact_arn , SourceType = "Model"
739
+ )["AssociationSummaries" ]
740
+
741
+ if len (model_associations ) == 0 :
742
+ raise Exception (f"No models associated with model artifact { model_artifact_arn } " )
743
+
727
744
training_job_associations = sagemaker_session .sagemaker_client .list_associations (
728
- DestinationArn = model_associations ["AssociationSummaries" ][ 0 ]["SourceArn" ],
745
+ DestinationArn = model_associations [0 ]["SourceArn" ],
729
746
SourceType = "SageMakerTrainingJob" ,
730
- )
747
+ )["AssociationSummaries" ]
748
+
749
+ if len (training_job_associations ) == 0 :
750
+ raise Exception (
751
+ f"No training jobs associated with models for model artifact { model_artifact_arn } "
752
+ )
753
+
731
754
dataset_associations = sagemaker_session .sagemaker_client .list_associations (
732
- DestinationArn = training_job_associations ["AssociationSummaries" ][ 0 ]["SourceArn" ],
755
+ DestinationArn = training_job_associations [0 ]["SourceArn" ],
733
756
SourceType = "DataSet" ,
734
757
)
735
758
0 commit comments