Skip to content

Commit 430601d

Browse files
committed
more logging info for static pipeline test data setup
1 parent 1c3b9a4 commit 430601d

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

tests/integ/sagemaker/lineage/conftest.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539539
_deploy_static_endpoint(
540540
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
541541
)
542-
542+
logging.info(f"Using static pipeline {pipeline_execution_arn}")
543543
return pipeline_execution_arn
544544

545545

@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608608
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
609609
)
610610

611+
model_artifact_arn = static_model_artifact.artifact_arn
611612
query_result = LineageQuery(sagemaker_session).query(
612-
start_arns=[static_model_artifact.artifact_arn],
613+
start_arns=[model_artifact_arn],
613614
query_filter=query_filter,
614615
direction=LineageQueryDirectionEnum.ASCENDANTS,
615616
include_edges=False,
616617
)
618+
logging.info(
619+
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
620+
)
617621
training_jobs = []
618622
for vertex in query_result.vertices:
619623
training_jobs.append(vertex.to_lineage_object())
620624

625+
if not training_jobs:
626+
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")
627+
621628
return training_jobs[0]
622629

623630

@@ -643,7 +650,9 @@ def static_transform_job_trial_component(
643650
@pytest.fixture
644651
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
645652
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653+
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")
646654

655+
# if the endpoint doesn't exist deploy it
647656
if endpoint_arn is None:
648657
_deploy_static_endpoint(
649658
execution_arn=static_pipeline_execution_arn,
@@ -664,8 +673,11 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
664673
)
665674
)
666675

676+
endpoint_context = context[0]
677+
context_arn = endpoint_context["ContextArn"]
678+
logging.info(f"Using context {context_arn} for static endpoint context")
667679
yield context.EndpointContext.load(
668-
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
680+
endpoint_context["ContextName"], sagemaker_session=sagemaker_session
669681
)
670682

671683

@@ -709,27 +721,41 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
709721
)
710722
)
711723

712-
yield artifact.ModelArtifact.load(
713-
artifacts[0]["ArtifactArn"], sagemaker_session=sagemaker_session
714-
)
724+
artifact_arn = artifacts[0]["ArtifactArn"]
725+
logging.info(f"Using static model artifact {artifact_arn}")
726+
yield artifact.ModelArtifact.load(artifact_arn, sagemaker_session=sagemaker_session)
715727

716728

717729
@pytest.fixture
718730
def static_dataset_artifact(static_model_artifact, sagemaker_session):
731+
model_artifact_arn = static_model_artifact.artifact_arn
719732
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
720-
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
733+
DestinationArn=model_artifact_arn, SourceType="DataSet"
734+
)
735+
logging.info(
736+
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
721737
)
722738
if len(dataset_associations["AssociationSummaries"]) == 0:
723739
# no directly associated dataset. work backwards from the model
724740
model_associations = sagemaker_session.sagemaker_client.list_associations(
725-
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
726-
)
741+
DestinationArn=model_artifact_arn, SourceType="Model"
742+
)["AssociationSummaries"]
743+
744+
if len(model_associations) == 0:
745+
raise Exception(f"No models associated with model artifact {model_artifact_arn}")
746+
727747
training_job_associations = sagemaker_session.sagemaker_client.list_associations(
728-
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
748+
DestinationArn=model_associations[0]["SourceArn"],
729749
SourceType="SageMakerTrainingJob",
730-
)
750+
)["AssociationSummaries"]
751+
752+
if len(training_job_associations) == 0:
753+
raise Exception(
754+
f"No training jobs associated with models for model artifact {model_artifact_arn}"
755+
)
756+
731757
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
732-
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
758+
DestinationArn=training_job_associations[0]["SourceArn"],
733759
SourceType="DataSet",
734760
)
735761

0 commit comments

Comments
 (0)