Skip to content

Commit 95fae83

Browse files
committed
feature: Add support for SageMaker lineage queries context
1 parent 72d1246 commit 95fae83

File tree

5 files changed

+165
-4
lines changed

5 files changed

+165
-4
lines changed

src/sagemaker/lineage/context.py

+64
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,58 @@ def training_job_arns(
333333
training_job_arns.append(trial_component["Source"]["SourceArn"])
334334
return training_job_arns
335335

336+
def processing_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Collect the ARNs of the processing jobs found in this context's lineage.

    The lineage graph is queried for trial components whose source is a
    processing job; each matching trial component is then described and the
    ARN of its source is collected.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction in which the
            lineage graph is traversed starting from this context.

    Returns:
        list of str: Processing job ARNs.
    """
    processing_components = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.PROCESSING_JOB],
    )
    lineage = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=processing_components,
        direction=direction,
        include_edges=False,
    )
    # One DescribeTrialComponent call per vertex: the vertex only carries the
    # trial component ARN, the job ARN lives in the component's "Source".
    return [
        self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=_utils.get_resource_name_from_arn(vertex.arn)
        )["Source"]["SourceArn"]
        for vertex in lineage.vertices
    ]
361+
362+
def trial_components_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all trial components that appear in this context's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction in which the
            lineage graph is traversed starting from this context.

    Returns:
        list of str: Trial components ARNs.
    """
    query_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    query_result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=query_filter,
        direction=direction,
        include_edges=False,
    )
    # The query is restricted to trial-component vertices, so each vertex ARN
    # already is a trial component ARN. Describing every component and reading
    # its "Source" would return job ARNs instead and fail for components that
    # have no source.
    return [vertex.arn for vertex in query_result.vertices]
387+
336388
def pipeline_execution_arn(
337389
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
338390
) -> str:
@@ -351,3 +403,15 @@ def pipeline_execution_arn(
351403
return tag["Value"]
352404

353405
return None
406+
407+
408+
class ModelPackageGroup(Context):
    """An Amazon SageMaker model package group context, which is part of a SageMaker lineage."""

    def pipeline_execution_arn(self) -> str:
        """Get the ARN for the pipeline execution associated with this model package group (if any).

        Returns:
            str: A pipeline execution ARN, or None when this context's properties
            hold no ``PipelineExecutionArn`` entry.
        """
        # Tolerate a missing properties mapping or a missing key instead of
        # raising, so "no associated pipeline execution" is reported as None.
        return (self.properties or {}).get("PipelineExecutionArn")

src/sagemaker/lineage/query.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class LineageSourceEnum(Enum):
4141
MODEL_REPLACE = "ModelReplaced"
4242
TENSORBOARD = "TensorBoard"
4343
TRAINING_JOB = "TrainingJob"
44+
PROCESSING_JOB = "ProcessingJob"
45+
TRANSFORM_JOB = "TransformJob"
4446

4547

4648
class LineageQueryDirectionEnum(Enum):
@@ -87,6 +89,7 @@ def to_lineage_object(self):
8789
from sagemaker.lineage.artifact import Artifact, ModelArtifact
8890
from sagemaker.lineage.context import Context, EndpointContext
8991
from sagemaker.lineage.artifact import DatasetArtifact
92+
from sagemaker.lineage.action import Action
9093

9194
if self.lineage_entity == LineageEntityEnum.CONTEXT.value:
9295
resource_name = get_resource_name_from_arn(self.arn)
@@ -103,6 +106,8 @@ def to_lineage_object(self):
103106
return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
104107
return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session)
105108

109+
if self.lineage_entity == LineageEntityEnum.ACTION.value:
110+
return Action.load(action_name=self.arn, sagemaker_session=self._session)
106111
raise ValueError("Vertex cannot be converted to a lineage object.")
107112

108113

@@ -155,11 +160,11 @@ def __init__(
155160
def _to_request_dict(self):
156161
"""Convert the lineage filter to its API representation."""
157162
filter_request = {}
158-
if self.entities:
163+
if self.sources:
159164
filter_request["Types"] = list(
160165
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
161166
)
162-
if self.sources:
167+
if self.entities:
163168
filter_request["LineageTypes"] = list(
164169
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
165170
)

tests/integ/sagemaker/lineage/conftest.py

+58-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
from tests.integ.sagemaker.lineage.helpers import name, names
3737

3838
SLEEP_TIME_SECONDS = 1
39-
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
40-
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"
39+
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline88"
40+
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint88"
41+
STATIC_MODEL_PACKAGE_GROUP_NAME = "SdkIntegTestStaticPipeline88ModelPackageGroup"
4142

4243

4344
@pytest.fixture
@@ -543,6 +544,29 @@ def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
543544
)
544545

545546

547+
@pytest.fixture
def static_model_package_group_context(sagemaker_session, static_pipeline_execution_arn):
    """Yield the lineage context of the static pipeline's model package group.

    The model package group ARN is resolved first, then the contexts whose
    source URI is that ARN are listed; exactly one context is expected.

    Raises:
        Exception: If the number of contexts listed for the model package
            group ARN is not exactly one.
    """
    model_package_group_arn = get_model_package_group_arn_from_static_pipeline(sagemaker_session)

    contexts = sagemaker_session.sagemaker_client.list_contexts(SourceUri=model_package_group_arn)[
        "ContextSummaries"
    ]
    if len(contexts) != 1:
        # Adjacent string literals (rather than backslash continuations inside
        # one literal) keep source indentation out of the error message.
        raise Exception(
            "Got an unexpected number of Contexts for "
            f"model package group {STATIC_MODEL_PACKAGE_GROUP_NAME} from pipeline "
            f"execution {static_pipeline_execution_arn}. "
            f"Expected 1 but got {len(contexts)}"
        )

    yield context.ModelPackageGroup.load(
        contexts[0]["ContextName"], sagemaker_session=sagemaker_session
    )
568+
569+
546570
@pytest.fixture
547571
def static_model_artifact(sagemaker_session, static_pipeline_execution_arn):
548572
model_package_arn = get_model_package_arn_from_static_pipeline(
@@ -590,6 +614,31 @@ def static_dataset_artifact(static_model_artifact, sagemaker_session):
590614
)
591615

592616

617+
@pytest.fixture
def static_image_artifact(static_model_artifact, sagemaker_session):
    """Yield the image artifact that feeds the static model artifact.

    An association of source type ``Image`` pointing directly at the model
    artifact is preferred. When there is none, the chain is walked backwards:
    model artifact -> model -> training job -> image.
    """
    client = sagemaker_session.sagemaker_client

    image_associations = client.list_associations(
        DestinationArn=static_model_artifact.artifact_arn, SourceType="Image"
    )
    if not image_associations["AssociationSummaries"]:
        # No image is associated directly; work backwards from the model.
        model_links = client.list_associations(
            DestinationArn=static_model_artifact.artifact_arn, SourceType="Model"
        )
        model_arn = model_links["AssociationSummaries"][0]["SourceArn"]

        training_job_links = client.list_associations(
            DestinationArn=model_arn,
            SourceType="SageMakerTrainingJob",
        )
        training_job_arn = training_job_links["AssociationSummaries"][0]["SourceArn"]

        image_associations = client.list_associations(
            DestinationArn=training_job_arn,
            SourceType="Image",
        )

    image_arn = image_associations["AssociationSummaries"][0]["SourceArn"]
    yield artifact.ImageArtifact.load(image_arn, sagemaker_session=sagemaker_session)
640+
641+
593642
def get_endpoint_arn_from_static_pipeline(sagemaker_session):
594643
try:
595644
endpoint_arn = sagemaker_session.sagemaker_client.describe_endpoint(
@@ -604,6 +653,13 @@ def get_endpoint_arn_from_static_pipeline(sagemaker_session):
604653
raise e
605654

606655

656+
def get_model_package_group_arn_from_static_pipeline(sagemaker_session):
    """Return the ARN of the model package group named by STATIC_MODEL_PACKAGE_GROUP_NAME."""
    description = sagemaker_session.sagemaker_client.describe_model_package_group(
        ModelPackageGroupName=STATIC_MODEL_PACKAGE_GROUP_NAME
    )
    return description["ModelPackageGroupArn"]
661+
662+
607663
def get_model_package_arn_from_static_pipeline(pipeline_execution_arn, sagemaker_session):
608664
# get the model package ARN from the pipeline
609665
pipeline_execution_steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(

tests/integ/sagemaker/lineage/test_endpoint_context.py

+16
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def test_training_job_arns(
6060
assert "training-job" in arn
6161

6262

63+
def test_processing_job_arns(static_endpoint_context):
    """The endpoint context reports at least one ARN, and each one names a processing job."""
    arns = static_endpoint_context.processing_job_arns()

    assert len(arns) > 0
    for processing_job_arn in arns:
        assert "processing-job" in processing_job_arn
69+
70+
71+
def test_trial_components_arns(static_endpoint_context):
    """The endpoint context reports at least one ARN, and each one contains "job"."""
    arns = static_endpoint_context.trial_components_arns()

    assert len(arns) > 0
    for component_arn in arns:
        assert "job" in component_arn
77+
78+
6379
def test_pipeline_execution_arn(static_endpoint_context, static_pipeline_execution_arn):
6480
pipeline_execution_arn = static_endpoint_context.pipeline_execution_arn()
6581

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code to test SageMaker ``ModelPackageGroup``"""
14+
from __future__ import absolute_import
15+
16+
17+
def test_pipeline_execution_arn(static_model_package_group_context, static_pipeline_execution_arn):
    """The model package group context resolves to the static pipeline execution ARN."""
    assert (
        static_model_package_group_context.pipeline_execution_arn()
        == static_pipeline_execution_arn
    )

0 commit comments

Comments
 (0)