Skip to content

feature: Add support for SageMaker lineage queries in action #2853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 94 additions & 1 deletion src/sagemaker/lineage/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@
"""This module contains code to create and manage SageMaker ``Actions``."""
from __future__ import absolute_import

from typing import Optional, Iterator
from typing import Optional, Iterator, List
from datetime import datetime

from sagemaker import Session
from sagemaker.apiutils import _base_types
from sagemaker.lineage import _api_types, _utils
from sagemaker.lineage._api_types import ActionSource, ActionSummary
from sagemaker.lineage.artifact import Artifact
from sagemaker.lineage.context import Context

from sagemaker.lineage.query import (
LineageQuery,
LineageFilter,
LineageSourceEnum,
LineageEntityEnum,
LineageQueryDirectionEnum,
)


class Action(_base_types.Record):
Expand Down Expand Up @@ -250,3 +260,86 @@ def list(
max_results=max_results,
next_token=next_token,
)

def artifacts(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
"""Use a lineage query to retrieve all artifacts that use this action.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of Artifacts: Artifacts.
"""
query_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.action_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]


class ModelPackageApprovalAction(Action):
"""An Amazon SageMaker model package approval action, which is part of a SageMaker lineage."""

def datasets(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
"""Use a lineage query to retrieve all upstream datasets that use this action.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of Artifacts: Artifacts representing a dataset.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.action_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]

def model_package(self):
"""Get model package from model package approval action.

Returns:
Model package.
"""
source_uri = self.source.source_uri
if source_uri is None:
return None

model_package_name = source_uri.split("/")[1]
return self.sagemaker_session.sagemaker_client.describe_model_package(
ModelPackageName=model_package_name
)

def endpoints(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Context]:
"""Use a lineage query to retrieve downstream endpoint contexts that use this action.

Args:
direction (LineageQueryDirectionEnum, optional): The query direction.

Returns:
list of Contexts: Contexts representing an endpoint.
"""
query_filter = LineageFilter(
entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT]
)
query_result = LineageQuery(self.sagemaker_session).query(
start_arns=[self.action_arn],
query_filter=query_filter,
direction=direction,
include_edges=False,
)
return [vertex.to_lineage_object() for vertex in query_result.vertices]
5 changes: 3 additions & 2 deletions src/sagemaker/lineage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class LineageSourceEnum(Enum):
MODEL_REPLACE = "ModelReplaced"
TENSORBOARD = "TensorBoard"
TRAINING_JOB = "TrainingJob"
APPROVAL = "Approval"


class LineageQueryDirectionEnum(Enum):
Expand Down Expand Up @@ -203,11 +204,11 @@ def __init__(
def _to_request_dict(self):
"""Convert the lineage filter to its API representation."""
filter_request = {}
if self.entities:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops. Good catch

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1!

if self.sources:
filter_request["Types"] = list(
map(lambda x: x.value if isinstance(x, LineageSourceEnum) else x, self.sources)
)
if self.sources:
if self.entities:
filter_request["LineageTypes"] = list(
map(lambda x: x.value if isinstance(x, LineageEntityEnum) else x, self.entities)
)
Expand Down
43 changes: 43 additions & 0 deletions tests/integ/sagemaker/lineage/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
association,
artifact,
)
from sagemaker.lineage.query import (
LineageFilter,
LineageEntityEnum,
LineageSourceEnum,
LineageQuery,
LineageQueryDirectionEnum,
)
from sagemaker.model import ModelPackage
from tests.integ.test_workflow import test_end_to_end_pipeline_successful_execution
from sagemaker.workflow.pipeline import _PipelineExecution
Expand Down Expand Up @@ -514,6 +521,42 @@ def _get_static_pipeline_execution_arn(sagemaker_session):
return pipeline_execution_arn


@pytest.fixture
def static_approval_action(
sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
):
query_filter = LineageFilter(
entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.APPROVAL]
)
query_result = LineageQuery(sagemaker_session).query(
start_arns=[static_endpoint_context.context_arn],
query_filter=query_filter,
direction=LineageQueryDirectionEnum.ASCENDANTS,
include_edges=False,
)
action_name = query_result.vertices[0].arn.split("/")[1]
yield action.ModelPackageApprovalAction.load(
action_name=action_name, sagemaker_session=sagemaker_session
)


@pytest.fixture
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
query_filter = LineageFilter(
entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
)
query_result = LineageQuery(sagemaker_session).query(
start_arns=[static_endpoint_context.context_arn],
query_filter=query_filter,
direction=LineageQueryDirectionEnum.ASCENDANTS,
include_edges=False,
)
model_approval_actions = []
for vertex in query_result.vertices:
model_approval_actions.append(vertex.to_lineage_object())
yield model_approval_actions[0]


@pytest.fixture
def static_endpoint_context(sagemaker_session, static_pipeline_execution_arn):
endpoint_arn = get_endpoint_arn_from_static_pipeline(sagemaker_session)
Expand Down
48 changes: 48 additions & 0 deletions tests/integ/sagemaker/lineage/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from sagemaker.lineage import action
from sagemaker.lineage.query import LineageQueryDirectionEnum


def test_create_delete(action_obj):
Expand Down Expand Up @@ -117,3 +118,50 @@ def test_tags(action_obj, sagemaker_session):
# length of actual tags will be greater than 1
assert len(actual_tags) > 0
assert [actual_tags[-1]] == tags


def test_upstream_artifacts(static_model_deployment_action):
artifacts_from_query = static_model_deployment_action.artifacts(
direction=LineageQueryDirectionEnum.ASCENDANTS
)
assert len(artifacts_from_query) > 0
for artifact in artifacts_from_query:
assert "artifact" in artifact.artifact_arn


def test_downstream_artifacts(static_approval_action):
artifacts_from_query = static_approval_action.artifacts(
direction=LineageQueryDirectionEnum.DESCENDANTS
)
assert len(artifacts_from_query) > 0
for artifact in artifacts_from_query:
assert "artifact" in artifact.artifact_arn


def test_datasets(static_approval_action, static_dataset_artifact, sagemaker_session):

sagemaker_session.sagemaker_client.add_association(
SourceArn=static_dataset_artifact.artifact_arn,
DestinationArn=static_approval_action.action_arn,
AssociationType="ContributedTo",
)
time.sleep(3)
artifacts_from_query = static_approval_action.datasets()

assert len(artifacts_from_query) > 0
for artifact in artifacts_from_query:
assert "artifact" in artifact.artifact_arn
assert artifact.artifact_type == "DataSet"

sagemaker_session.sagemaker_client.delete_association(
SourceArn=static_dataset_artifact.artifact_arn,
DestinationArn=static_approval_action.action_arn,
)


def test_endpoints(static_approval_action):
endpoint_contexts_from_query = static_approval_action.endpoints()
assert len(endpoint_contexts_from_query) > 0
for endpoint in endpoint_contexts_from_query:
assert endpoint.context_type == "Endpoint"
assert "endpoint" in endpoint.context_arn
21 changes: 21 additions & 0 deletions tests/unit/sagemaker/lineage/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest.mock

from sagemaker.lineage import action, _api_types
from sagemaker.lineage._api_types import ActionSource


def test_create(sagemaker_session):
Expand Down Expand Up @@ -333,3 +334,23 @@ def test_create_delete_with_association(sagemaker_session):
delete_with_association_expected_calls
== sagemaker_session.sagemaker_client.delete_association.mock_calls
)


def test_model_package(sagemaker_session):
obj = action.ModelPackageApprovalAction(
sagemaker_session,
action_name="abcd-aws-model-package",
source=ActionSource(
source_uri="arn:aws:sagemaker:us-west-2:123456789012:model-package/pipeline88modelpackage/1",
source_type="ARN",
),
status="updated-status",
properties={"k1": "v1"},
properties_to_remove=["k2"],
)
sagemaker_session.sagemaker_client.describe_model_package.return_value = {}
obj.model_package()

sagemaker_session.sagemaker_client.describe_model_package.assert_called_with(
ModelPackageName="pipeline88modelpackage",
)