Skip to content

Commit 706a60f

Browse files
committed
more logging info for static pipeline test data setup
1 parent 1c3b9a4 commit 706a60f

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

tests/integ/sagemaker/lineage/conftest.py

+39-14
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
539539
_deploy_static_endpoint(
540540
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
541541
)
542-
542+
logging.info(f"Using static pipeline {pipeline_execution_arn}")
543543
return pipeline_execution_arn
544544

545545

@@ -608,16 +608,23 @@ def static_training_job_trial_component(
608608
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
609609
)
610610

611+
model_artifact_arn = static_model_artifact.artifact_arn
611612
query_result = LineageQuery(sagemaker_session).query(
612-
start_arns=[static_model_artifact.artifact_arn],
613+
start_arns=[model_artifact_arn],
613614
query_filter=query_filter,
614615
direction=LineageQueryDirectionEnum.ASCENDANTS,
615616
include_edges=False,
616617
)
618+
logging.info(
619+
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
620+
)
617621
training_jobs = []
618622
for vertex in query_result.vertices:
619623
training_jobs.append(vertex.to_lineage_object())
620624

625+
if not training_jobs:
626+
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")
627+
621628
return training_jobs[0]
622629

623630

@@ -643,6 +650,7 @@ def static_transform_job_trial_component(
643650
@pytest.fixture
644651
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
645652
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653+
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")
646654

647655
if endpoint_arn is None:
648656
_deploy_static_endpoint(
@@ -651,6 +659,8 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
651659
)
652660
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
653661

662+
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
663+
654664
contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=endpoint_arn)[
655665
"ContextSummaries"
656666
]
@@ -664,9 +674,10 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
664674
)
665675
)
666676

667-
yield context.EndpointContext.load(
668-
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
669-
)
677+
endpoint_context = context[0]
678+
context_arn = endpoint_context["ContextArn"]
679+
logging.info(f"Using context {context_arn} for static endpoint context")
680+
yield context.EndpointContext.load(endpoint_context["ContextName"], sagemaker_session=sagemaker_session)
670681

671682

672683
@pytest.fixture
@@ -709,27 +720,41 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
709720
)
710721
)
711722

712-
yield artifact.ModelArtifact.load(
713-
artifacts[0]["ArtifactArn"], sagemaker_session=sagemaker_session
714-
)
723+
artifact_arn = artifacts[0]["ArtifactArn"]
724+
logging.info(f"Using static model artifact {artifact_arn}")
725+
yield artifact.ModelArtifact.load(artifact_arn, sagemaker_session=sagemaker_session)
715726

716727

717728
@pytest.fixture
718729
def static_dataset_artifact(static_model_artifact, sagemaker_session):
730+
model_artifact_arn = static_model_artifact.artifact_arn
719731
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
720-
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
732+
DestinationArn=model_artifact_arn, SourceType="DataSet"
733+
)
734+
logging.info(
735+
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
721736
)
722737
if len(dataset_associations["AssociationSummaries"]) == 0:
723738
# no directly associated dataset. work backwards from the model
724739
model_associations = sagemaker_session.sagemaker_client.list_associations(
725-
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
726-
)
740+
DestinationArn=model_artifact_arn, SourceType="Model"
741+
)["AssociationSummaries"]
742+
743+
if len(model_associations) == 0:
744+
raise Exception(f"No models associated with model artifact {model_artifact_arn}")
745+
727746
training_job_associations = sagemaker_session.sagemaker_client.list_associations(
728-
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
747+
DestinationArn=model_associations[0]["SourceArn"],
729748
SourceType="SageMakerTrainingJob",
730-
)
749+
)["AssociationSummaries"]
750+
751+
if len(training_job_associations) == 0:
752+
raise Exception(
753+
f"No training jobs associated with models for model artifact {model_artifact_arn}"
754+
)
755+
731756
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
732-
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
757+
DestinationArn=training_job_associations[0]["SourceArn"],
733758
SourceType="DataSet",
734759
)
735760

0 commit comments

Comments
 (0)