Skip to content

Commit b82fb8a

Browse files
authored
feature: Add support for SageMaker lineage queries in action (#2853)
1 parent 9d259b3 commit b82fb8a

File tree

5 files changed

+209
-3
lines changed

5 files changed

+209
-3
lines changed

src/sagemaker/lineage/action.py

+94-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,23 @@
1313
"""This module contains code to create and manage SageMaker ``Actions``."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional, Iterator
16+
from typing import Optional, Iterator, List
1717
from datetime import datetime
1818

1919
from sagemaker import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
23+
from sagemaker.lineage.artifact import Artifact
24+
from sagemaker.lineage.context import Context
25+
26+
from sagemaker.lineage.query import (
27+
LineageQuery,
28+
LineageFilter,
29+
LineageSourceEnum,
30+
LineageEntityEnum,
31+
LineageQueryDirectionEnum,
32+
)
2333

2434

2535
class Action(_base_types.Record):
@@ -250,3 +260,86 @@ def list(
250260
max_results=max_results,
251261
next_token=next_token,
252262
)
263+
264+
def artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
    """Return the artifacts linked to this action in the lineage graph.

    Issues a lineage query that starts at this action's ARN, is filtered to
    artifact entities only, and does not request edges.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction in which the
            lineage graph is traversed from this action. Defaults to
            ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        list of Artifacts: One lineage object per vertex returned by the query.
    """
    artifacts_only = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.action_arn],
        query_filter=artifacts_only,
        direction=direction,
        include_edges=False,
    )

    lineage_objects = []
    for vertex in result.vertices:
        lineage_objects.append(vertex.to_lineage_object())
    return lineage_objects
283+
284+
285+
class ModelPackageApprovalAction(Action):
    """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage."""

    def _lineage_objects_from_query(self, query_filter, direction):
        """Run a lineage query rooted at this action and convert its vertices.

        Shared by ``datasets`` and ``endpoints`` so the query boilerplate lives
        in one place.

        Args:
            query_filter (LineageFilter): Filter applied to the lineage query.
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list: The result of ``to_lineage_object()`` for every vertex
            returned by the query (edges are not requested).
        """
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.action_arn],
            query_filter=query_filter,
            direction=direction,
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in query_result.vertices]

    def datasets(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[Artifact]:
        """Use a lineage query to retrieve the dataset artifacts related to this action.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``ASCENDANTS`` (upstream of this action).

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        return self._lineage_objects_from_query(query_filter, direction)

    def model_package(self):
        """Get model package from model package approval action.

        The model package name is taken from the second ``/``-separated
        segment of the action's source URI and passed to
        ``describe_model_package``.

        Returns:
            The ``describe_model_package`` response, or ``None`` when the
            action's source URI is ``None``.

        Raises:
            IndexError: If the source URI contains no ``/`` separator.
        """
        source_uri = self.source.source_uri
        if source_uri is None:
            return None

        model_package_name = source_uri.split("/")[1]
        return self.sagemaker_session.sagemaker_client.describe_model_package(
            ModelPackageName=model_package_name
        )

    def endpoints(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List[Context]:
        """Use a lineage query to retrieve the endpoint contexts related to this action.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``DESCENDANTS`` (downstream of this action).

        Returns:
            list of Contexts: Contexts representing an endpoint.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
        )
        return self._lineage_objects_from_query(query_filter, direction)

src/sagemaker/lineage/query.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class LineageSourceEnum(Enum):
4343
MODEL_REPLACE = "ModelReplaced"
4444
TENSORBOARD = "TensorBoard"
4545
TRAINING_JOB = "TrainingJob"
46+
APPROVAL = "Approval"
4647

4748

4849
class LineageQueryDirectionEnum(Enum):
@@ -203,11 +204,11 @@ def __init__(
203204
def _to_request_dict(self):
204205
"""Convert the lineage filter to its API representation."""
205206
filter_request = {}
206-
if self.entities:
207+
if self.sources:
207208
filter_request["Types"] = list(
208209
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
209210
)
210-
if self.sources:
211+
if self.entities:
211212
filter_request["LineageTypes"] = list(
212213
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
213214
)

tests/integ/sagemaker/lineage/conftest.py

+43
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
association,
2626
artifact,
2727
)
28+
from sagemaker.lineage.query import (
29+
LineageFilter,
30+
LineageEntityEnum,
31+
LineageSourceEnum,
32+
LineageQuery,
33+
LineageQueryDirectionEnum,
34+
)
2835
from sagemaker.model import ModelPackage
2936
from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution
3037
from sagemaker.workflow.pipeline import _PipelineExecution
@@ -514,6 +521,42 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
514521
return pipeline_execution_arn
515522

516523

524+
@pytest.fixture
def static_approval_action(
    sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
):
    """Yield the approval action found upstream of the static endpoint context.

    Queries the ascendants of the static endpoint context for actions whose
    source is ``Approval``, takes the first vertex returned, and loads it as a
    ``ModelPackageApprovalAction`` by the name embedded in its ARN.
    """
    # NOTE(review): static_pipeline_execution_arn is requested but never read
    # here; presumably it only forces the static pipeline fixture to run first.
    approval_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.APPROVAL]
    )
    lineage_query = LineageQuery(sagemaker_session)
    upstream = lineage_query.query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=approval_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )

    # The action name is the segment that follows the first "/" in the ARN.
    first_arn = upstream.vertices[0].arn
    action_name = first_arn.split("/")[1]

    approval_action = action.ModelPackageApprovalAction.load(
        action_name=action_name, sagemaker_session=sagemaker_session
    )
    yield approval_action
541+
542+
543+
@pytest.fixture
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
    """Yield the first model deployment action upstream of the static endpoint context.

    Queries the ascendants of the static endpoint context for actions whose
    source is ``MODEL_DEPLOYMENT``, converts every returned vertex to its
    lineage object, and yields the first one.
    """
    deployment_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
    )
    lineage_query = LineageQuery(sagemaker_session)
    upstream = lineage_query.query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=deployment_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    deployment_actions = [vertex.to_lineage_object() for vertex in upstream.vertices]
    yield deployment_actions[0]
558+
559+
517560
@pytest.fixture
518561
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
519562
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)

tests/integ/sagemaker/lineage/test_action.py

+48
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
from sagemaker.lineage import action
23+
from sagemaker.lineage.query import LineageQueryDirectionEnum
2324

2425

2526
def test_create_delete(action_obj):
@@ -117,3 +118,50 @@ def test_tags(action_obj, sagemaker_session):
117118
# length of actual tags will be greater than 1
118119
assert len(actual_tags) > 0
119120
assert [actual_tags[-1]] == tags
121+
122+
123+
def test_upstream_artifacts(static_model_deployment_action):
    """An ASCENDANTS query from the deployment action returns at least one artifact."""
    upstream = static_model_deployment_action.artifacts(
        direction=LineageQueryDirectionEnum.ASCENDANTS
    )

    assert len(upstream) > 0
    # Every returned object must carry an artifact ARN.
    assert all("artifact" in found.artifact_arn for found in upstream)
130+
131+
132+
def test_downstream_artifacts(static_approval_action):
    """A DESCENDANTS query from the approval action returns at least one artifact."""
    downstream = static_approval_action.artifacts(
        direction=LineageQueryDirectionEnum.DESCENDANTS
    )

    assert len(downstream) > 0
    # Every returned object must carry an artifact ARN.
    assert all("artifact" in found.artifact_arn for found in downstream)
139+
140+
141+
def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session):
    """``datasets()`` returns the dataset artifact associated with the approval action.

    Temporarily associates the static dataset artifact with the approval
    action, checks that ``datasets()`` returns only ``DataSet`` artifacts, and
    removes the association again.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    sagemaker_client.add_association(
        SourceArn=static_dataset_artifact.artifact_arn,
        DestinationArn=static_approval_action.action_arn,
        AssociationType="ContributedTo",
    )
    try:
        # NOTE(review): presumably gives the new association time to become
        # visible to lineage queries -- confirm the required delay.
        time.sleep(3)
        artifacts_from_query = static_approval_action.datasets()

        assert len(artifacts_from_query) > 0
        for artifact in artifacts_from_query:
            assert "artifact" in artifact.artifact_arn
            assert artifact.artifact_type == "DataSet"
    finally:
        # Always remove the association, even when an assertion fails, so the
        # shared static fixtures are not left modified for other tests.
        sagemaker_client.delete_association(
            SourceArn=static_dataset_artifact.artifact_arn,
            DestinationArn=static_approval_action.action_arn,
        )
160+
161+
162+
def test_endpoints(static_approval_action):
    """``endpoints()`` with its default direction returns only endpoint contexts."""
    found_contexts = static_approval_action.endpoints()

    assert len(found_contexts) > 0
    for context in found_contexts:
        assert context.context_type == "Endpoint"
        assert "endpoint" in context.context_arn

tests/unit/sagemaker/lineage/test_action.py

+21
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import unittest.mock
1717

1818
from sagemaker.lineage import action, _api_types
19+
from sagemaker.lineage._api_types import ActionSource
1920

2021

2122
def test_create(sagemaker_session):
@@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session):
333334
delete_with_association_expected_calls
334335
== sagemaker_session.sagemaker_client.delete_association.mock_calls
335336
)
337+
338+
339+
def test_model_package(sagemaker_session):
    """``model_package()`` describes the package named in the action's source ARN."""
    approval_source = ActionSource(
        source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1",
        source_type="ARN",
    )
    approval_action = action.ModelPackageApprovalAction(
        sagemaker_session,
        action_name="abcd-aws-model-package",
        source=approval_source,
        status="updated-status",
        properties={"k1": "v1"},
        properties_to_remove=["k2"],
    )
    sagemaker_session.sagemaker_client.describe_model_package.return_value = {}

    approval_action.model_package()

    # The name passed on is the ARN segment after the first "/".
    sagemaker_session.sagemaker_client.describe_model_package.assert_called_with(
        ModelPackageName="pipeline88modelpackage",
    )

0 commit comments

Comments
 (0)