Skip to content

Commit 366f586

Browse files
committed
more logging info for static pipeline test data setup
1 parent 1c3b9a4 commit 366f586

File tree

1 file changed

+39
-16
lines changed

1 file changed

+39
-16
lines changed

tests/integ/sagemaker/lineage/conftest.py

+39-16
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539539
_deploy_static_endpoint(
540540
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
541541
)
542-
542+
logging.info(f"Using static pipeline {pipeline_execution_arn}")
543543
return pipeline_execution_arn
544544

545545

@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608608
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
609609
)
610610

611+
model_artifact_arn = static_model_artifact.artifact_arn
611612
query_result = LineageQuery(sagemaker_session).query(
612-
start_arns=[static_model_artifact.artifact_arn],
613+
start_arns=[model_artifact_arn],
613614
query_filter=query_filter,
614615
direction=LineageQueryDirectionEnum.ASCENDANTS,
615616
include_edges=False,
616617
)
618+
logging.info(
619+
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
620+
)
617621
training_jobs = []
618622
for vertex in query_result.vertices:
619623
training_jobs.append(vertex.to_lineage_object())
620624

625+
if not training_jobs:
626+
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")
627+
621628
return training_jobs[0]
622629

623630

@@ -642,7 +649,7 @@ def static_transform_job_trial_component(
642649

643650
@pytest.fixture
644651
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
645-
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
652+
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")
646653

647654
if endpoint_arn is None:
648655
_deploy_static_endpoint(
@@ -651,6 +658,8 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
651658
)
652659
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653660

661+
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
662+
654663
contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=endpoint_arn)[
655664
"ContextSummaries"
656665
]
@@ -663,10 +672,10 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
663672
Expected 1 but got {len(contexts)}"
664673
)
665674
)
666-
667-
yield context.EndpointContext.load(
668-
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
669-
)
675+
context = context[0]
676+
context_arn = context["ContextArn"]
677+
logging.info(f"Using context {context_arn} for static endpoint context")
678+
yield context.EndpointContext.load(context["ContextName"], sagemaker_session=sagemaker_session)
670679

671680

672681
@pytest.fixture
@@ -709,27 +718,41 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
709718
)
710719
)
711720

712-
yield artifact.ModelArtifact.load(
713-
artifacts[0]["ArtifactArn"], sagemaker_session=sagemaker_session
714-
)
721+
artifact_arn = artifacts[0]["ArtifactArn"]
722+
logging.info(f"Using static model artifact {artifact_arn}")
723+
yield artifact.ModelArtifact.load(artifact_arn, sagemaker_session=sagemaker_session)
715724

716725

717726
@pytest.fixture
718727
def static_dataset_artifact(static_model_artifact, sagemaker_session):
728+
model_artifact_arn = static_model_artifact.artifact_arn
719729
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
720-
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
730+
DestinationArn=model_artifact_arn, SourceType="DataSet"
731+
)
732+
logging.info(
733+
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
721734
)
722735
if len(dataset_associations["AssociationSummaries"]) == 0:
723736
# no directly associated dataset. work backwards from the model
724737
model_associations = sagemaker_session.sagemaker_client.list_associations(
725-
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
726-
)
738+
DestinationArn=model_artifact_arn, SourceType="Model"
739+
)["AssociationSummaries"]
740+
741+
if len(model_associations) == 0:
742+
raise Exception(f"No models associated with model artifact {model_artifact_arn}")
743+
727744
training_job_associations = sagemaker_session.sagemaker_client.list_associations(
728-
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
745+
DestinationArn=model_associations[0]["SourceArn"],
729746
SourceType="SageMakerTrainingJob",
730-
)
747+
)["AssociationSummaries"]
748+
749+
if len(training_job_associations) == 0:
750+
raise Exception(
751+
f"No training jobs associated with models for model artifact {model_artifact_arn}"
752+
)
753+
731754
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
732-
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
755+
DestinationArn=training_job_associations[0]["SourceArn"],
733756
SourceType="DataSet",
734757
)
735758

0 commit comments

Comments
 (0)