Skip to content

Commit c355230

Browse files
committed
feature: Add support for SageMaker lineage queries context
1 parent 87c1d2c commit c355230

File tree

5 files changed

+165
-2
lines changed

5 files changed

+165
-2
lines changed

src/sagemaker/lineage/context.py

+64
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,57 @@ def training_job_arns(
333333
training_job_arns.append(trial_component["Source"]["SourceArn"])
334334
return training_job_arns
335335

336+
def processing_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all processing jobs that appear in the endpoint's lineage.

    The lineage graph is queried from this context's ARN for trial component
    vertices whose source is a processing job; each vertex is then described
    to read its source ARN.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction passed to
            the lineage query. Defaults to ``ASCENDANTS``.

    Returns:
        list of str: Processing job ARNs.
    """
    processing_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.PROCESSING_JOB],
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=processing_filter,
        direction=direction,
        include_edges=False,
    )

    arns = []
    for vertex in lineage.vertices:
        # The vertex only carries the trial component ARN; the job ARN lives
        # in the component's "Source" description, so one describe call is
        # issued per vertex.
        component_name = _utils.get_resource_name_from_arn(vertex.arn)
        description = self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=component_name
        )
        source = description["Source"]
        arns.append(source["SourceArn"])
    return arns
362+
363+
def trial_components_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get the source ARNs of all trial components in the endpoint's lineage.

    The lineage graph is queried from this context's ARN for every trial
    component vertex (no source filter). Each vertex is then described and the
    ``SourceArn`` of its ``Source`` is collected.

    NOTE(review): the values returned are the ARNs of the resources behind
    each trial component (``Source.SourceArn``), not the ARNs of the trial
    components themselves - confirm this is the intended contract.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction passed to
            the lineage query. Defaults to ``ASCENDANTS``.

    Returns:
        list of str: One source ARN per trial component vertex found.
    """
    query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    query_result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=query_filter,
        direction=direction,
        include_edges=False,
    )

    # Named for what is collected; the previous name (transform_job_arns) was
    # a copy-paste leftover and did not match this unfiltered query.
    trial_component_source_arns = []
    for vertex in query_result.vertices:
        trial_component_name = _utils.get_resource_name_from_arn(vertex.arn)
        trial_component = self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=trial_component_name
        )
        trial_component_source_arns.append(trial_component["Source"]["SourceArn"])
    return trial_component_source_arns
386+
336387
def pipeline_execution_arn(
337388
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
338389
) -> str:
@@ -351,3 +402,16 @@ def pipeline_execution_arn(
351402
return tag["Value"]
352403

353404
return None
405+
406+
407+
class ModelPackageGroup(Context):
    """An Amazon SageMaker model package group context, which is part of a SageMaker lineage."""

    def pipeline_execution_arn(self) -> str:
        """Get the ARN for the pipeline execution associated with this model package group (if any).

        Returns:
            str: The value stored under ``PipelineExecutionArn`` in this
                context's properties, or ``None`` when that key is absent.
        """
        # .get() rather than indexing: a group with no recorded pipeline
        # execution yields None instead of raising KeyError.
        return self.properties.get("PipelineExecutionArn")

src/sagemaker/lineage/query.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class LineageSourceEnum(Enum):
4141
MODEL_REPLACE = "ModelReplaced"
4242
TENSORBOARD = "TensorBoard"
4343
TRAINING_JOB = "TrainingJob"
44+
PROCESSING_JOB = "ProcessingJob"
45+
TRANSFORM_JOB = "TransformJob"
4446

4547

4648
class LineageQueryDirectionEnum(Enum):
@@ -87,6 +89,7 @@ def to_lineage_object(self):
8789
from sagemaker.lineage.artifact import Artifact, ModelArtifact
8890
from sagemaker.lineage.context import Context, EndpointContext
8991
from sagemaker.lineage.artifact import DatasetArtifact
92+
from sagemaker.lineage.action import Action
9093

9194
if self.lineage_entity == LineageEntityEnum.CONTEXT.value:
9295
resource_name = get_resource_name_from_arn(self.arn)
@@ -103,6 +106,8 @@ def to_lineage_object(self):
103106
return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
104107
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
105108

109+
if self.lineage_entity == LineageEntityEnum.ACTION.value:
110+
return Action.load(action_name=self.arn, sagemaker_session=self._session)
106111
raise ValueError("Vertex cannot be converted to a lineage object.")
107112

108113

@@ -155,11 +160,11 @@ def __init__(
155160
def _to_request_dict(self):
156161
"""Convert the lineage filter to its API representation."""
157162
filter_request = {}
158-
if self.entities:
163+
if self.sources:
159164
filter_request["Types"] = list(
160165
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
161166
)
162-
if self.sources:
167+
if self.entities:
163168
filter_request["LineageTypes"] = list(
164169
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
165170
)

tests/integ/sagemaker/lineage/conftest.py

+58
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SLEEP_TIME_SECONDS = 1
3939
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
4040
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"
41+
STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline17ModelPackageGroup"
4142

4243

4344
@pytest.fixture
@@ -543,6 +544,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
543544
)
544545

545546

547+
@pytest.fixture
def static_model_package_group_context(sagemaker_session, static_pipeline_execution_arn):
    """Yield the lineage context of the static model package group.

    Looks up the model package group ARN, lists the lineage contexts whose
    source URI is that ARN, and loads the single matching context as a
    ``ModelPackageGroup``.

    Raises:
        Exception: If the number of contexts found is not exactly one.
    """
    model_package_group_arn = get_model_package_group_arn_from_static_pipeline(sagemaker_session)

    contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=model_package_group_arn)[
        "ContextSummaries"
    ]
    if len(contexts) != 1:
        # Implicit string concatenation keeps the message independent of how
        # this block is indented; backslash continuations inside a literal
        # would pull the next line's leading whitespace into the text.
        raise Exception(
            "Got an unexpected number of Contexts for "
            f"model package group {STATIC_MODEL_PACKAGE_GROUP_NAME} from pipeline "
            f"execution {static_pipeline_execution_arn}. "
            f"Expected 1 but got {len(contexts)}"
        )

    yield context.ModelPackageGroup.load(
        contexts[0]["ContextName"], sagemaker_session=sagemaker_session
    )
568+
569+
546570
@pytest.fixture
547571
def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
548572
model_package_arn = get_model_package_arn_from_static_pipeline(
@@ -590,6 +614,31 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
590614
)
591615

592616

617+
@pytest.fixture
def static_image_artifact(static_model_artifact, sagemaker_session):
    """Yield the image artifact associated with the static model artifact.

    First looks for an ``Image`` association pointing directly at the model
    artifact. If there is none, walks back through the first ``Model``
    association and then the first ``SageMakerTrainingJob`` association, and
    takes the ``Image`` association of that training job instead.
    """
    sagemaker_client = sagemaker_session.sagemaker_client

    image_associations = sagemaker_client.list_associations(
        DestinationArn=static_model_artifact.artifact_arn, SourceType="Image"
    )
    if not image_associations["AssociationSummaries"]:
        # no directly associated image. work backwards from the model
        model_associations = sagemaker_client.list_associations(
            DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
        )
        training_job_associations = sagemaker_client.list_associations(
            DestinationArn=model_associations["AssociationSummaries"][0]["SourceArn"],
            SourceType="SageMakerTrainingJob",
        )
        image_associations = sagemaker_client.list_associations(
            DestinationArn=training_job_associations["AssociationSummaries"][0]["SourceArn"],
            SourceType="Image",
        )

    yield artifact.ImageArtifact.load(
        image_associations["AssociationSummaries"][0]["SourceArn"],
        sagemaker_session=sagemaker_session,
    )
640+
641+
593642
def get_endpoint_arn_from_static_pipeline(sagemaker_session):
594643
try:
595644
endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint(
@@ -604,6 +653,15 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session):
604653
raise e
605654

606655

656+
def get_model_package_group_arn_from_static_pipeline(sagemaker_session):
    """Return the ARN of the static model package group.

    Describes the group named by ``STATIC_MODEL_PACKAGE_GROUP_NAME`` and reads
    ``ModelPackageGroupArn`` from the response.
    """
    description = sagemaker_session.sagemaker_client.describe_model_package_group(
        ModelPackageGroupName=STATIC_MODEL_PACKAGE_GROUP_NAME
    )
    return description["ModelPackageGroupArn"]
663+
664+
607665
def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session):
608666
# get the model package ARN from the pipeline
609667
pipeline_execution_steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(

tests/integ/sagemaker/lineage/test_endpoint_context.py

+16
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ def test_training_job_arns(
4141
assert "training-job" in arn
4242

4343

44+
def test_processing_job_arns(static_endpoint_context):
    """The endpoint lineage yields at least one ARN, and each is a processing job ARN."""
    arns = static_endpoint_context.processing_job_arns()

    assert len(arns) > 0
    for job_arn in arns:
        assert "processing-job" in job_arn
50+
51+
52+
def test_trial_components_arns(static_endpoint_context):
    """The endpoint lineage yields at least one ARN, and each contains "job"."""
    arns = static_endpoint_context.trial_components_arns()

    assert len(arns) > 0
    for source_arn in arns:
        assert "job" in source_arn
58+
59+
4460
def test_pipeline_execution_arn(static_endpoint_context, static_pipeline_execution_arn):
4561
pipeline_execution_arn = static_endpoint_context.pipeline_execution_arn()
4662

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``ModelPackageGroup``"""
14+
from __future__ import absolute_import
15+
16+
17+
def test_pipeline_execution_arn(static_model_package_group_context, static_pipeline_execution_arn):
    """The model package group context reports the static pipeline execution ARN."""
    reported_arn = static_model_package_group_context.pipeline_execution_arn()

    assert reported_arn == static_pipeline_execution_arn

0 commit comments

Comments
 (0)