Skip to content

Commit 24368f6

Browse files
committed
more logging info for static pipeline test data setup
1 parent 1c3b9a4 commit 24368f6

File tree

1 file changed

+39
-12
lines changed

1 file changed

+39
-12
lines changed

tests/integ/sagemaker/lineage/conftest.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539539
_deploy_static_endpoint(
540540
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
541541
)
542-
542+
logging.info(f"Using static pipeline {pipeline_execution_arn}")
543543
return pipeline_execution_arn
544544

545545

@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608608
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
609609
)
610610

611+
model_artifact_arn = static_model_artifact.artifact_arn
611612
query_result = LineageQuery(sagemaker_session).query(
612-
start_arns=[static_model_artifact.artifact_arn],
613+
start_arns=[model_artifact_arn],
613614
query_filter=query_filter,
614615
direction=LineageQueryDirectionEnum.ASCENDANTS,
615616
include_edges=False,
616617
)
618+
logging.info(
619+
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
620+
)
617621
training_jobs = []
618622
for vertex in query_result.vertices:
619623
training_jobs.append(vertex.to_lineage_object())
620624

625+
if not training_jobs:
626+
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")
627+
621628
return training_jobs[0]
622629

623630

@@ -643,6 +650,7 @@ def static_transform_job_trial_component(
643650
@pytest.fixture
644651
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
645652
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653+
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")
646654

647655
if endpoint_arn is None:
648656
_deploy_static_endpoint(
@@ -651,6 +659,8 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
651659
)
652660
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653661

662+
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
663+
654664
contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=endpoint_arn)[
655665
"ContextSummaries"
656666
]
@@ -664,8 +674,11 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
664674
)
665675
)
666676

677+
endpoint_context = context[0]
678+
context_arn = endpoint_context["ContextArn"]
679+
logging.info(f"Using context {context_arn} for static endpoint context")
667680
yield context.EndpointContext.load(
668-
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
681+
endpoint_context["ContextName"], sagemaker_session=sagemaker_session
669682
)
670683

671684

@@ -709,27 +722,41 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
709722
)
710723
)
711724

712-
yield artifact.ModelArtifact.load(
713-
artifacts[0]["ArtifactArn"], sagemaker_session=sagemaker_session
714-
)
725+
artifact_arn = artifacts[0]["ArtifactArn"]
726+
logging.info(f"Using static model artifact {artifact_arn}")
727+
yield artifact.ModelArtifact.load(artifact_arn, sagemaker_session=sagemaker_session)
715728

716729

717730
@pytest.fixture
718731
def static_dataset_artifact(static_model_artifact, sagemaker_session):
732+
model_artifact_arn = static_model_artifact.artifact_arn
719733
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
720-
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
734+
DestinationArn=model_artifact_arn, SourceType="DataSet"
735+
)
736+
logging.info(
737+
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
721738
)
722739
if len(dataset_associations["AssociationSummaries"]) == 0:
723740
# no directly associated dataset. work backwards from the model
724741
model_associations = sagemaker_session.sagemaker_client.list_associations(
725-
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
726-
)
742+
DestinationArn=model_artifact_arn, SourceType="Model"
743+
)["AssociationSummaries"]
744+
745+
if len(model_associations) == 0:
746+
raise Exception(f"No models associated with model artifact {model_artifact_arn}")
747+
727748
training_job_associations = sagemaker_session.sagemaker_client.list_associations(
728-
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
749+
DestinationArn=model_associations[0]["SourceArn"],
729750
SourceType="SageMakerTrainingJob",
730-
)
751+
)["AssociationSummaries"]
752+
753+
if len(training_job_associations) == 0:
754+
raise Exception(
755+
f"No training jobs associated with models for model artifact {model_artifact_arn}"
756+
)
757+
731758
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
732-
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
759+
DestinationArn=training_job_associations[0]["SourceArn"],
733760
SourceType="DataSet",
734761
)
735762

0 commit comments

Comments
 (0)