From ad4edcd9023bfa45edfbd8f399c08e849df67895 Mon Sep 17 00:00:00 2001 From: Dana Benson Date: Tue, 22 Mar 2022 09:57:02 -0700 Subject: [PATCH] more logging info for static pipeline test data setup --- tests/integ/sagemaker/lineage/conftest.py | 44 ++++++++++++++++++----- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 45954dd5c6..a03108e044 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session): _deploy_static_endpoint( execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session ) - + logging.info(f"Using static pipeline {pipeline_execution_arn}") return pipeline_execution_arn @@ -608,16 +608,23 @@ def static_training_job_trial_component( entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB] ) + model_artifact_arn = static_model_artifact.artifact_arn query_result = LineageQuery(sagemaker_session).query( - start_arns=[static_model_artifact.artifact_arn], + start_arns=[model_artifact_arn], query_filter=query_filter, direction=LineageQueryDirectionEnum.ASCENDANTS, include_edges=False, ) + logging.info( + f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}" + ) training_jobs = [] for vertex in query_result.vertices: training_jobs.append(vertex.to_lineage_object()) + if not training_jobs: + raise Exception(f"No training job found for static model artifact {model_artifact_arn}") + return training_jobs[0] @@ -643,7 +650,9 @@ def static_transform_job_trial_component( @pytest.fixture def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session) + logging.info(f"Using endpoint {endpoint_arn} from static pipeline") + # if the endpoint doesn't exist deploy it if endpoint_arn is None: _deploy_static_endpoint( execution_arn=static_pipeline_execution_arn, @@ -664,8 +673,11 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn): ) ) + endpoint_context = contexts[0] + context_arn = endpoint_context["ContextArn"] + logging.info(f"Using context {context_arn} for static endpoint context") yield context.EndpointContext.load( - contexts[0]["ContextName"], sagemaker_session=sagemaker_session + endpoint_context["ContextName"], sagemaker_session=sagemaker_session ) @@ -717,20 +729,34 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn): @pytest.fixture def static_dataset_artifact(static_model_artifact, sagemaker_session): + model_artifact_arn = static_model_artifact.artifact_arn dataset_associations = sagemaker_session.sagemaker_client.list_associations( - DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet" + DestinationArn=model_artifact_arn, SourceType="DataSet" + ) + logging.info( + f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}" ) if len(dataset_associations["AssociationSummaries"]) == 0: # no directly associated dataset. work backwards from the model model_associations = sagemaker_session.sagemaker_client.list_associations( - DestinationArn=static_model_artifact.artifact_arn, SourceType="Model" - ) + DestinationArn=model_artifact_arn, SourceType="Model" + )["AssociationSummaries"] + + if len(model_associations) == 0: + raise Exception(f"No models associated with model artifact {model_artifact_arn}") + training_job_associations = sagemaker_session.sagemaker_client.list_associations( - DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"], + DestinationArn=model_associations[0]["SourceArn"], SourceType="SageMakerTrainingJob", - ) + )["AssociationSummaries"] + + if len(training_job_associations) == 0: + raise Exception( + f"No training jobs associated with models for model artifact {model_artifact_arn}" + ) + dataset_associations = sagemaker_session.sagemaker_client.list_associations( - DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"], + DestinationArn=training_job_associations[0]["SourceArn"], SourceType="DataSet", )