38
38
SLEEP_TIME_SECONDS = 1
39
39
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
40
40
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"
41
+ STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup"
41
42
42
43
43
44
@pytest .fixture
@@ -543,6 +544,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
543
544
)
544
545
545
546
547
+ @pytest .fixture
548
+ def static_model_package_group_context (sagemaker_session , static_pipeline_execution_arn ):
549
+
550
+ model_package_group_arn = get_model_package_group_arn_from_static_pipeline (sagemaker_session )
551
+
552
+ contexts = sagemaker_session .sagemaker_client .list_contexts (SourceUri = model_package_group_arn )[
553
+ "ContextSummaries"
554
+ ]
555
+ if len (contexts ) != 1 :
556
+ raise (
557
+ Exception (
558
+ f"Got an unexpected number of Contexts for \
559
+ model package group { STATIC_MODEL_PACKAGE_GROUP_NAME } from pipeline \
560
+ execution { static_pipeline_execution_arn } . \
561
+ Expected 1 but got { len (contexts )} "
562
+ )
563
+ )
564
+
565
+ yield context .ModelPackageGroup .load (
566
+ contexts [0 ]["ContextName" ], sagemaker_session = sagemaker_session
567
+ )
568
+
569
+
546
570
@pytest .fixture
547
571
def static_model_artifact (sagemaker_session , static_pipeline_execution_arn ):
548
572
model_package_arn = get_model_package_arn_from_static_pipeline (
@@ -590,6 +614,31 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
590
614
)
591
615
592
616
617
+ @pytest .fixture
618
+ def static_image_artifact (static_model_artifact , sagemaker_session ):
619
+ dataset_associations = sagemaker_session .sagemaker_client .list_associations (
620
+ DestinationArn = static_model_artifact .artifact_arn , SourceType = "Image"
621
+ )
622
+ if len (dataset_associations ["AssociationSummaries" ]) == 0 :
623
+ # no directly associated dataset. work backwards from the model
624
+ model_associations = sagemaker_session .sagemaker_client .list_associations (
625
+ DestinationArn = static_model_artifact .artifact_arn , SourceType = "Model"
626
+ )
627
+ training_job_associations = sagemaker_session .sagemaker_client .list_associations (
628
+ DestinationArn = model_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
629
+ SourceType = "SageMakerTrainingJob" ,
630
+ )
631
+ dataset_associations = sagemaker_session .sagemaker_client .list_associations (
632
+ DestinationArn = training_job_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
633
+ SourceType = "Image" ,
634
+ )
635
+
636
+ yield artifact .ImageArtifact .load (
637
+ dataset_associations ["AssociationSummaries" ][0 ]["SourceArn" ],
638
+ sagemaker_session = sagemaker_session ,
639
+ )
640
+
641
+
593
642
def get_endpoint_arn_from_static_pipeline (sagemaker_session ):
594
643
try :
595
644
endpoint_arn = sagemaker_session .sagemaker_client .describe_endpoint (
@@ -604,6 +653,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session):
604
653
raise e
605
654
606
655
656
+ def get_model_package_group_arn_from_static_pipeline (sagemaker_session ):
657
+ static_model_package_group_arn = (
658
+ sagemaker_session .sagemaker_client .describe_model_package_group (
659
+ ModelPackageGroupName = STATIC_MODEL_PACKAGE_GROUP_NAME
660
+ )["ModelPackageGroupArn" ]
661
+ )
662
+ return static_model_package_group_arn
663
+
664
+
607
665
def get_model_package_arn_from_static_pipeline (pipeline_execution_arn , sagemaker_session ):
608
666
# get the model package ARN from the pipeline
609
667
pipeline_execution_steps = sagemaker_session .sagemaker_client .list_pipeline_execution_steps (
0 commit comments