Skip to content

Commit d87de58

Browse files
committed
feature: Add support for SageMaker lineage queries in action
1 parent fd7a335 commit d87de58

File tree

5 files changed

+208
-3
lines changed

5 files changed

+208
-3
lines changed

src/sagemaker/lineage/action.py

+93-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@
1313
"""This module contains code to create and manage SageMaker ``Actions``."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional, Iterator
16+
from typing import Optional, Iterator, List
1717
from datetime import datetime
1818

1919
from sagemaker import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
23+
from sagemaker.lineage.artifact import Artifact
24+
25+
from sagemaker.lineage.query import (
26+
LineageQuery,
27+
LineageFilter,
28+
LineageSourceEnum,
29+
LineageEntityEnum,
30+
LineageQueryDirectionEnum,
31+
)
2332

2433

2534
class Action(_base_types.Record):
@@ -250,3 +259,86 @@ def list(
250259
max_results=max_results,
251260
next_token=next_token,
252261
)
262+
263+
def artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
    """Run a lineage query from this action and return the artifacts it finds.

    Args:
        direction (LineageQueryDirectionEnum, optional): Direction in which the
            lineage graph is traversed, starting at this action. Defaults to
            ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        list of Artifacts: One lineage object per vertex returned by the query.
    """
    # Only artifact entities are of interest; no source-type restriction.
    artifact_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.action_arn],
        query_filter=artifact_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in result.vertices]
282+
283+
284+
class ModelPackageApprovalAction(Action):
    """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage."""

    def _query_lineage_objects(
        self, query_filter: LineageFilter, direction: LineageQueryDirectionEnum
    ) -> List:
        """Run a lineage query that starts at this action and load every matching vertex.

        Args:
            query_filter (LineageFilter): Restricts which entities/sources are returned.
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list: The result of ``to_lineage_object()`` for each vertex of the query result.
        """
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.action_arn],
            query_filter=query_filter,
            direction=direction,
            # Only the vertices are consumed, so edges are not requested.
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in query_result.vertices]

    def datasets(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[Artifact]:
        """Use a lineage query to retrieve all upstream datasets that use this action.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        return self._query_lineage_objects(query_filter, direction)

    def model_package(self):
        """Get model package from model package approval action.

        The model package name is taken from the second "/"-separated segment of
        this action's source URI and passed to ``describe_model_package``.

        Returns:
            The ``describe_model_package`` response for the model package, or
            ``None`` when the action's source URI is ``None``.

        Raises:
            IndexError: If the source URI does not contain a "/".
        """
        source_uri = self.source.source_uri
        if source_uri is None:
            return None

        # e.g. ".../model-package/<name>/<version>" -> "<name>"
        model_package_name = source_uri.split("/")[1]
        return self.sagemaker_session.sagemaker_client.describe_model_package(
            ModelPackageName=model_package_name
        )

    def endpoints(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List:
        """Use a lineage query to retrieve downstream endpoint contexts that use this action.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

        Returns:
            list of Contexts: Contexts representing an endpoint.
        """
        query_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
        )
        return self._query_lineage_objects(query_filter, direction)

src/sagemaker/lineage/query.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class LineageSourceEnum(Enum):
4343
MODEL_REPLACE = "ModelReplaced"
4444
TENSORBOARD = "TensorBoard"
4545
TRAINING_JOB = "TrainingJob"
46+
APPROVAL = "Approval"
4647

4748

4849
class LineageQueryDirectionEnum(Enum):
@@ -203,11 +204,11 @@ def __init__(
203204
def _to_request_dict(self):
204205
"""Convert the lineage filter to its API representation."""
205206
filter_request = {}
206-
if self.entities:
207+
if self.sources:
207208
filter_request["Types"] = list(
208209
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
209210
)
210-
if self.sources:
211+
if self.entities:
211212
filter_request["LineageTypes"] = list(
212213
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
213214
)

tests/integ/sagemaker/lineage/conftest.py

+43
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
association,
2626
artifact,
2727
)
28+
from sagemaker.lineage.query import (
29+
LineageFilter,
30+
LineageEntityEnum,
31+
LineageSourceEnum,
32+
LineageQuery,
33+
LineageQueryDirectionEnum,
34+
)
2835
from sagemaker.model import ModelPackage
2936
from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution
3037
from sagemaker.workflow.pipeline import _PipelineExecution
@@ -514,6 +521,42 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
514521
return pipeline_execution_arn
515522

516523

524+
@pytest.fixture
def static_approval_action(
    sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
):
    """Yield the approval action found upstream of the static endpoint context.

    A lineage query restricted to ``Approval`` actions is run in the ASCENDANTS
    direction from the endpoint context; the first vertex returned is loaded as
    a ``ModelPackageApprovalAction``.
    """
    approval_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.APPROVAL]
    )
    lineage_query = LineageQuery(sagemaker_session)
    upstream = lineage_query.query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=approval_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    # The action name is the second "/"-separated segment of the vertex ARN.
    first_arn = upstream.vertices[0].arn
    approval_action_name = first_arn.split("/")[1]
    yield action.ModelPackageApprovalAction.load(
        action_name=approval_action_name, sagemaker_session=sagemaker_session
    )
541+
542+
543+
@pytest.fixture
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
    """Yield the first model deployment action upstream of the static endpoint context.

    A lineage query restricted to ``ModelDeployment`` actions is run in the
    ASCENDANTS direction from the endpoint context.

    Raises:
        IndexError: If the query returns no vertices.
    """
    query_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
    )
    query_result = LineageQuery(sagemaker_session).query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=query_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    # Only the first match is used, so convert just that vertex instead of
    # loading a lineage object for every vertex and discarding the rest.
    yield query_result.vertices[0].to_lineage_object()
558+
559+
517560
@pytest.fixture
518561
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
519562
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)

tests/integ/sagemaker/lineage/test_action.py

+48
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
from sagemaker.lineage import action
23+
from sagemaker.lineage.query import LineageQueryDirectionEnum
2324

2425

2526
def test_create_delete(action_obj):
@@ -117,3 +118,50 @@ def test_tags(action_obj, sagemaker_session):
117118
# length of actual tags will be greater than 1
118119
assert len(actual_tags) > 0
119120
assert [actual_tags[-1]] == tags
121+
122+
123+
def test_upstream_artifacts(static_model_deployment_action):
    """An ASCENDANTS query from a deployment action returns at least one artifact."""
    upstream = static_model_deployment_action.artifacts(
        direction=LineageQueryDirectionEnum.ASCENDANTS
    )

    assert len(upstream) > 0
    # Every returned object must carry an artifact ARN.
    for found in upstream:
        assert "artifact" in found.artifact_arn
131+
132+
def test_downstream_artifacts(static_approval_action):
    """A DESCENDANTS query from an approval action returns at least one artifact."""
    downstream = static_approval_action.artifacts(
        direction=LineageQueryDirectionEnum.DESCENDANTS
    )

    assert len(downstream) > 0
    # Every returned object must carry an artifact ARN.
    for found in downstream:
        assert "artifact" in found.artifact_arn
139+
140+
141+
def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session):
    """``datasets()`` returns the dataset artifact associated with the approval action.

    The test temporarily associates the static dataset artifact with the
    approval action and removes that association again afterwards.
    """
    sagemaker_client = sagemaker_session.sagemaker_client
    sagemaker_client.add_association(
        SourceArn=static_dataset_artifact.artifact_arn,
        DestinationArn=static_approval_action.action_arn,
        AssociationType="ContributedTo",
    )
    try:
        # NOTE(review): presumably gives the new association time to become
        # visible to lineage queries -- confirm the required delay.
        time.sleep(3)
        artifacts_from_query = static_approval_action.datasets()

        assert len(artifacts_from_query) > 0
        for artifact in artifacts_from_query:
            assert "artifact" in artifact.artifact_arn
            assert artifact.artifact_type == "DataSet"
    finally:
        # Always remove the association, even when an assertion above fails,
        # so the static resources are not left modified for later tests.
        sagemaker_client.delete_association(
            SourceArn=static_dataset_artifact.artifact_arn,
            DestinationArn=static_approval_action.action_arn,
        )
160+
161+
162+
def test_endpoints(static_approval_action):
    """``endpoints()`` with its default direction returns only endpoint contexts."""
    contexts = static_approval_action.endpoints()

    assert len(contexts) > 0
    for ctx in contexts:
        assert ctx.context_type == "Endpoint"
        assert "endpoint" in ctx.context_arn

tests/unit/sagemaker/lineage/test_action.py

+21
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import unittest.mock
1717

1818
from sagemaker.lineage import action, _api_types
19+
from sagemaker.lineage._api_types import ActionSource
1920

2021

2122
def test_create(sagemaker_session):
@@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session):
333334
delete_with_association_expected_calls
334335
== sagemaker_session.sagemaker_client.delete_association.mock_calls
335336
)
337+
338+
339+
def test_model_package(sagemaker_session):
    """``model_package()`` describes the package named in the source URI and returns the response."""
    describe_response = {}
    sagemaker_session.sagemaker_client.describe_model_package.return_value = describe_response
    obj = action.ModelPackageApprovalAction(
        sagemaker_session,
        action_name="abcd-aws-model-package",
        source=ActionSource(
            source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1",
            source_type="ARN",
        ),
        status="updated-status",
        properties={"k1": "v1"},
        properties_to_remove=["k2"],
    )

    result = obj.model_package()

    # The describe response must be handed back to the caller unchanged.
    assert result is describe_response
    # The package name is the second "/"-separated segment of the source URI.
    sagemaker_session.sagemaker_client.describe_model_package.assert_called_with(
        ModelPackageName="pipeline88modelpackage",
    )

0 commit comments

Comments
 (0)