Skip to content

Commit 897a752

Browse files
committed
feature: Add support for SageMaker lineage queries in action
1 parent fd7a335 commit 897a752

File tree

5 files changed

+229
-3
lines changed

5 files changed

+229
-3
lines changed

src/sagemaker/lineage/action.py

+119-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@
1313
"""This module contains code to create and manage SageMaker ``Actions``."""
1414
from __future__ import absolute_import
1515

16-
from typing import Optional, Iterator
16+
from typing import Optional, Iterator, List
1717
from datetime import datetime
1818

1919
from sagemaker import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
23+
from sagemaker.lineage.artifact import Artifact
24+
25+
from sagemaker.lineage.query import (
26+
LineageQuery,
27+
LineageFilter,
28+
LineageSourceEnum,
29+
LineageEntityEnum,
30+
LineageQueryDirectionEnum,
31+
)
2332

2433

2534
class Action(_base_types.Record):
@@ -250,3 +259,112 @@ def list(
250259
max_results=max_results,
251260
next_token=next_token,
252261
)
262+
263+
def upstream_artifacts(
    self,
    direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS,
) -> List[Artifact]:
    """Retrieve the artifacts found by a lineage query that starts at this action.

    Delegates to ``_artifacts``; the only difference is that the query
    direction defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of Artifacts: Artifacts returned by the lineage query.
    """
    upstream = self._artifacts(direction=direction)
    return upstream
275+
276+
def downstream_artifacts(
    self,
    direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS,
) -> List[Artifact]:
    """Retrieve the artifacts found by a lineage query that starts at this action.

    Delegates to ``_artifacts``; the only difference is that the query
    direction defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        list of Artifacts: Artifacts returned by the lineage query.
    """
    downstream = self._artifacts(direction=direction)
    return downstream
288+
289+
def _artifacts(
    self,
    direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH,
) -> List[Artifact]:
    """Run a lineage query from this action and return the artifact vertices.

    The query starts at ``self.action_arn``, is filtered to
    ``LineageEntityEnum.ARTIFACT`` entities and does not request edges.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        list of Artifacts: One lineage object per vertex in the query result.
    """
    artifacts_only = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
    lineage = LineageQuery(self.sagemaker_session)
    result = lineage.query(
        start_arns=[self.action_arn],
        query_filter=artifacts_only,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in result.vertices]
308+
309+
310+
class ModelPackageApprovalAction(Action):
    """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage."""

    def upstream_datasets(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS,
    ) -> List[Artifact]:
        """Run a lineage query from this action and return the dataset artifacts.

        The query starts at ``self.action_arn`` and is filtered to
        ``ARTIFACT`` entities whose source is ``LineageSourceEnum.DATASET``.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.DATASET],
        )
        lineage = LineageQuery(self.sagemaker_session)
        result = lineage.query(
            start_arns=[self.action_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        return [node.to_lineage_object() for node in result.vertices]

    def model_package(self):
        """Describe the model package referenced by this action's source URI.

        Returns:
            The ``describe_model_package`` response for the name taken from
            the second ``/``-separated segment of ``self.source.source_uri``,
            or None when the source URI is None.
        """
        uri = self.source.source_uri
        if uri is None:
            return None

        # NOTE(review): for an ARN shaped like
        # ".../model-package/<group>/<version>" this picks the <group> segment
        # and drops the version -- confirm describe_model_package accepts that.
        package_name = uri.split("/")[1]
        client = self.sagemaker_session.sagemaker_client
        return client.describe_model_package(ModelPackageName=package_name)

    def endpoints(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS,
    ):
        """Run a lineage query from this action and return the endpoint contexts.

        The query starts at ``self.action_arn`` and is filtered to
        ``CONTEXT`` entities whose source is ``LineageSourceEnum.ENDPOINT``.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

        Returns:
            list of Contexts: Contexts representing an endpoint.
        """
        endpoint_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT],
            sources=[LineageSourceEnum.ENDPOINT],
        )
        lineage = LineageQuery(self.sagemaker_session)
        result = lineage.query(
            start_arns=[self.action_arn],
            query_filter=endpoint_filter,
            direction=direction,
            include_edges=False,
        )
        return [node.to_lineage_object() for node in result.vertices]

src/sagemaker/lineage/query.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class LineageSourceEnum(Enum):
4343
MODEL_REPLACE = "ModelReplaced"
4444
TENSORBOARD = "TensorBoard"
4545
TRAINING_JOB = "TrainingJob"
46+
APPROVAL = "Approval"
4647

4748

4849
class LineageQueryDirectionEnum(Enum):
@@ -203,11 +204,11 @@ def __init__(
203204
def _to_request_dict(self):
204205
"""Convert the lineage filter to its API representation."""
205206
filter_request = {}
206-
if self.entities:
207+
if self.sources:
207208
filter_request["Types"] = list(
208209
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
209210
)
210-
if self.sources:
211+
if self.entities:
211212
filter_request["LineageTypes"] = list(
212213
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
213214
)

tests/integ/sagemaker/lineage/conftest.py

+43
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@
2525
association,
2626
artifact,
2727
)
28+
from sagemaker.lineage.query import (
29+
LineageFilter,
30+
LineageEntityEnum,
31+
LineageSourceEnum,
32+
LineageQuery,
33+
LineageQueryDirectionEnum,
34+
)
2835
from sagemaker.model import ModelPackage
2936
from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution
3037
from sagemaker.workflow.pipeline import _PipelineExecution
@@ -514,6 +521,42 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
514521
return pipeline_execution_arn
515522

516523

524+
@pytest.fixture
def static_approval_action(
    sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
):
    """Yield the approval action found by querying upward from the static endpoint context.

    The lineage query is filtered to ``ACTION`` entities with the ``APPROVAL``
    source; the first vertex returned is loaded as a
    ``ModelPackageApprovalAction``.
    """
    approval_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.APPROVAL],
    )
    lineage = LineageQuery(sagemaker_session)
    result = lineage.query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=approval_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    # The action name is the segment after the first "/" of the vertex ARN.
    first_arn = result.vertices[0].arn
    action_name = first_arn.split("/")[1]
    yield action.ModelPackageApprovalAction.load(
        action_name=action_name, sagemaker_session=sagemaker_session
    )
541+
542+
543+
@pytest.fixture
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
    """Yield the first model deployment action found upward from the static endpoint context.

    The lineage query is filtered to ``ACTION`` entities with the
    ``MODEL_DEPLOYMENT`` source; every vertex is converted to its lineage
    object and the first one is yielded.
    """
    deployment_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.MODEL_DEPLOYMENT],
    )
    lineage = LineageQuery(sagemaker_session)
    result = lineage.query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=deployment_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    deployment_actions = [node.to_lineage_object() for node in result.vertices]
    yield deployment_actions[0]
558+
559+
517560
@pytest.fixture
518561
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
519562
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)

tests/integ/sagemaker/lineage/test_action.py

+43
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,46 @@ def test_tags(action_obj, sagemaker_session):
117117
# length of actual tags will be greater than 1
118118
assert len(actual_tags) > 0
119119
assert [actual_tags[-1]] == tags
120+
121+
122+
def test_upstream_artifacts(static_model_deployment_action):
    """The deployment action has at least one upstream artifact, each with an artifact ARN."""
    found = static_model_deployment_action.upstream_artifacts()
    assert len(found) > 0
    for item in found:
        assert "artifact" in item.artifact_arn
127+
128+
129+
def test_downstream_artifact(static_approval_action):
    """The approval action has at least one downstream artifact, each with an artifact ARN."""
    found = static_approval_action.downstream_artifacts()
    assert len(found) > 0
    for item in found:
        assert "artifact" in item.artifact_arn
134+
135+
136+
def test_dataset(static_approval_action, static_dataset_artifact, sagemaker_session):
    """Associate a dataset artifact with the approval action and query it back.

    The association is created for the duration of the test and removed in a
    ``finally`` clause, so a failing assertion does not leave it behind on the
    static lineage entities.
    """
    client = sagemaker_session.sagemaker_client
    client.add_association(
        SourceArn=static_dataset_artifact.artifact_arn,
        DestinationArn=static_approval_action.action_arn,
        AssociationType="ContributedTo",
    )
    try:
        # Wait before querying, as the original test did; presumably this lets
        # the new association become visible to lineage queries -- TODO confirm.
        time.sleep(3)
        artifacts_from_query = static_approval_action.upstream_datasets()

        assert len(artifacts_from_query) > 0
        for dataset in artifacts_from_query:
            assert "artifact" in dataset.artifact_arn
            assert dataset.artifact_type == "DataSet"
    finally:
        # Always undo the association created above, even on assertion failure.
        client.delete_association(
            SourceArn=static_dataset_artifact.artifact_arn,
            DestinationArn=static_approval_action.action_arn,
        )
155+
156+
157+
def test_endpoints(static_approval_action):
    """The approval action has at least one downstream endpoint context."""
    contexts = static_approval_action.endpoints()
    assert len(contexts) > 0
    for ctx in contexts:
        assert ctx.context_type == "Endpoint"
        assert "endpoint" in ctx.context_arn

tests/unit/sagemaker/lineage/test_action.py

+21
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import unittest.mock
1717

1818
from sagemaker.lineage import action, _api_types
19+
from sagemaker.lineage._api_types import ActionSource
1920

2021

2122
def test_create(sagemaker_session):
@@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session):
333334
delete_with_association_expected_calls
334335
== sagemaker_session.sagemaker_client.delete_association.mock_calls
335336
)
337+
338+
339+
def test_model_package(sagemaker_session):
    """``model_package`` describes the package named by the source URI's second segment."""
    approval_source = ActionSource(
        source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1",
        source_type="ARN",
    )
    obj = action.ModelPackageApprovalAction(
        sagemaker_session,
        action_name="abcd-aws-model-package",
        source=approval_source,
        status="updated-status",
        properties={"k1": "v1"},
        properties_to_remove=["k2"],
    )
    describe = sagemaker_session.sagemaker_client.describe_model_package
    describe.return_value = {}

    obj.model_package()

    describe.assert_called_with(
        ModelPackageName="pipeline88modelpackage",
    )

0 commit comments

Comments
 (0)