Skip to content

more logging info for static pipeline test data setup #3019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 30, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions tests/integ/sagemaker/lineage/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
_deploy_static_endpoint(
execution_arn=pipeline_execution_arn, sagemaker_session=sagemaker_session
)

logging.info(f"Using static pipeline {pipeline_execution_arn}")
return pipeline_execution_arn


Expand Down Expand Up @@ -608,16 +608,23 @@ def static_training_job_trial_component(
entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
)

model_artifact_arn = static_model_artifact.artifact_arn
query_result = LineageQuery(sagemaker_session).query(
start_arns=[static_model_artifact.artifact_arn],
start_arns=[model_artifact_arn],
query_filter=query_filter,
direction=LineageQueryDirectionEnum.ASCENDANTS,
include_edges=False,
)
logging.info(
f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
)
training_jobs = []
for vertex in query_result.vertices:
training_jobs.append(vertex.to_lineage_object())

if not training_jobs:
raise Exception(f"No training job found for static model artifact {model_artifact_arn}")

return training_jobs[0]


Expand All @@ -643,7 +650,9 @@ def static_transform_job_trial_component(
@pytest.fixture
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
logging.info(f"Using endpoint {endpoint_arn} from static pipeline")

# if the endpoint doesn't exist deploy it
if endpoint_arn is None:
_deploy_static_endpoint(
execution_arn=static_pipeline_execution_arn,
Expand All @@ -664,8 +673,11 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
)
)

endpoint_context = contexts[0]
context_arn = endpoint_context["ContextArn"]
logging.info(f"Using context {context_arn} for static endpoint context")
yield context.EndpointContext.load(
contexts[0]["ContextName"], sagemaker_session=sagemaker_session
endpoint_context["ContextName"], sagemaker_session=sagemaker_session
)


Expand Down Expand Up @@ -717,20 +729,34 @@ def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):

@pytest.fixture
def static_dataset_artifact(static_model_artifact, sagemaker_session):
model_artifact_arn = static_model_artifact.artifact_arn
dataset_associations = sagemaker_session.sagemaker_client.list_associations(
DestinationArn=static_model_artifact.artifact_arn, SourceType="DataSet"
DestinationArn=model_artifact_arn, SourceType="DataSet"
)
logging.info(
f"Found {len(dataset_associations)} associated with model artifact {model_artifact_arn}"
)
if len(dataset_associations["AssociationSummaries"]) == 0:
# no directly associated dataset. work backwards from the model
model_associations = sagemaker_session.sagemaker_client.list_associations(
DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
)
DestinationArn=model_artifact_arn, SourceType="Model"
)["AssociationSummaries"]

if len(model_associations) == 0:
raise Exception(f"No models associated with model artifact {model_artifact_arn}")

training_job_associations = sagemaker_session.sagemaker_client.list_associations(
DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
DestinationArn=model_associations[0]["SourceArn"],
SourceType="SageMakerTrainingJob",
)
)["AssociationSummaries"]

if len(training_job_associations) == 0:
raise Exception(
f"No training jobs associated with models for model artifact {model_artifact_arn}"
)

dataset_associations = sagemaker_session.sagemaker_client.list_associations(
DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
DestinationArn=training_job_associations[0]["SourceArn"],
SourceType="DataSet",
)

Expand Down