Skip to content

Commit bc3825e

Browse files
authored
more logging info for static pipeline test data setup (#3019)
1 parent 686bcb2 commit bc3825e

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

tests/integ/sagemaker/lineage/conftest.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539539
_deploy_static_endpoint(
540540
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
541541
)
542-
542+
logging.info(f"Using static pipeline {pipeline_execution_arn}")
543543
return pipeline_execution_arn
544544

545545

@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608608
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
609609
)
610610

611+
model_artifact_arn = static_model_artifact.artifact_arn
611612
query_result = LineageQuery(sagemaker_session).query(
612-
start_arns=[static_model_artifact.artifact_arn],
613+
start_arns=[model_artifact_arn],
613614
query_filter=query_filter,
614615
direction=LineageQueryDirectionEnum.ASCENDANTS,
615616
include_edges=False,
616617
)
618+
logging.info(
619+
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
620+
)
617621
training_jobs = []
618622
for vertex in query_result.vertices:
619623
training_jobs.append(vertex.to_lineage_object())
620624

625+
if not training_jobs:
626+
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")
627+
621628
return training_jobs[0]
622629

623630

@@ -643,7 +650,9 @@ def static_transform_job_trial_component(
643650
@pytest.fixture
644651
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
645652
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653+
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")
646654

655+
# if the endpoint doesn't exist deploy it
647656
if endpoint_arn is None:
648657
_deploy_static_endpoint(
649658
execution_arn=static_pipeline_execution_arn,
@@ -664,8 +673,11 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
664673
)
665674
)
666675

676+
endpoint_context = contexts[0]
677+
context_arn = endpoint_context["ContextArn"]
678+
logging.info(f"Using context {context_arn} for static endpoint context")
667679
yield context.EndpointContext.load(
668-
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
680+
endpoint_context["ContextName"], sagemaker_session=sagemaker_session
669681
)
670682

671683

@@ -717,20 +729,34 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
717729

718730
@pytest.fixture
719731
def static_dataset_artifact(static_model_artifact, sagemaker_session):
732+
model_artifact_arn = static_model_artifact.artifact_arn
720733
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
721-
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
734+
DestinationArn=model_artifact_arn, SourceType="DataSet"
735+
)
736+
logging.info(
737+
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
722738
)
723739
if len(dataset_associations["AssociationSummaries"]) == 0:
724740
# no directly associated dataset. work backwards from the model
725741
model_associations = sagemaker_session.sagemaker_client.list_associations(
726-
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
727-
)
742+
DestinationArn=model_artifact_arn, SourceType="Model"
743+
)["AssociationSummaries"]
744+
745+
if len(model_associations) == 0:
746+
raise Exception(f"No models associated with model artifact {model_artifact_arn}")
747+
728748
training_job_associations = sagemaker_session.sagemaker_client.list_associations(
729-
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
749+
DestinationArn=model_associations[0]["SourceArn"],
730750
SourceType="SageMakerTrainingJob",
731-
)
751+
)["AssociationSummaries"]
752+
753+
if len(training_job_associations) == 0:
754+
raise Exception(
755+
f"No training jobs associated with models for model artifact {model_artifact_arn}"
756+
)
757+
732758
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
733-
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
759+
DestinationArn=training_job_associations[0]["SourceArn"],
734760
SourceType="DataSet",
735761
)
736762

0 commit comments

Comments
 (0)