Skip to content

feature: Add models_v2 under lineage context #2800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/sagemaker/lineage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,55 @@ def models(self) -> List[association.Association]:
]
return model_list

def models_v2(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Artifact]:
    """Get artifacts representing models from the context lineage by querying lineage data.

    The lookup runs in two passes. The first pass queries the model-deployment
    actions reachable from this context in ``direction``. The second pass uses
    each of those actions as the start of a query for model artifacts.

    Args:
        direction (LineageQueryDirectionEnum, optional): The direction of the first
            query, from this context to model-deployment actions. The second query
            always walks descendants of each model-deployment action.

    Returns:
        list of Artifacts: Artifacts representing a model. Empty when the first
            query yields no result or when no model vertices are found.
    """
    # First pass: find the model-deployment action vertices linked to this context.
    deployment_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
    )
    model_deployment_query_result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=deployment_filter,
        direction=direction,
        include_edges=False,
    )
    if not model_deployment_query_result:
        return []

    # Second pass: from every model-deployment action, collect the model artifact
    # vertices. The filter is identical for each query, so it is built only once.
    model_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL]
    )
    model_vertices = []
    for deployment_vertex in model_deployment_query_result.vertices:
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[deployment_vertex.arn],
            query_filter=model_filter,
            direction=LineageQueryDirectionEnum.DESCENDANTS,
            include_edges=False,
        )
        model_vertices.extend(query_result.vertices)

    # An empty vertex list naturally produces an empty result list.
    return [vertex.to_lineage_object() for vertex in model_vertices]

def dataset_artifacts(
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
Expand Down
14 changes: 7 additions & 7 deletions tests/integ/sagemaker/lineage/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from tests.integ.sagemaker.lineage.helpers import name, names

SLEEP_TIME_SECONDS = 1
SLEEP_TIME_TWO_SECONDS = 2
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"

Expand Down Expand Up @@ -360,12 +361,10 @@ def endpoint_context_obj(sagemaker_session):

@pytest.fixture
def model_obj(sagemaker_session):
model = context.Context.create(
context_name=name(),
model = artifact.Artifact.create(
artifact_name=name(),
artifact_type="Model",
source_uri="bar1",
source_type="test-source-type1",
context_type="Model",
description="test-description",
properties={"k1": "v1"},
sagemaker_session=sagemaker_session,
)
Expand Down Expand Up @@ -417,11 +416,12 @@ def endpoint_context_associate_with_model(sagemaker_session, endpoint_action_obj

association.Association.create(
source_arn=endpoint_action_obj.action_arn,
destination_arn=model_obj.context_arn,
destination_arn=model_obj.artifact_arn,
sagemaker_session=sagemaker_session,
)
yield obj
time.sleep(SLEEP_TIME_SECONDS)
# sleep 2 seconds since lineage injection takes longer
time.sleep(SLEEP_TIME_TWO_SECONDS)
obj.delete(disassociate=True)


Expand Down
16 changes: 15 additions & 1 deletion tests/integ/sagemaker/lineage/test_endpoint_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,31 @@
# language governing permissions and limitations under the License.
"""This module contains code to test SageMaker ``Contexts``"""
from __future__ import absolute_import
import time

SLEEP_TIME_ONE_SECONDS = 1


def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj):
    """Check the associations returned by ``models()`` for an endpoint context.

    Every returned association must go from the endpoint action to the model
    and carry the ``ModelDeployment`` / ``Model`` source and destination types.
    """
    model_list = endpoint_context_associate_with_model.models()
    for model in model_list:
        assert model.source_arn == endpoint_action_obj.action_arn
        # The association is created against the model's artifact ARN, so that is
        # the only destination to compare with (the old context ARN check is stale).
        assert model.destination_arn == model_obj.artifact_arn
        assert model.source_type == "ModelDeployment"
        assert model.destination_type == "Model"


def test_model_v2(endpoint_context_associate_with_model, model_obj, sagemaker_session):
    """Check that ``models_v2()`` returns exactly one artifact matching ``model_obj``."""
    # NOTE(review): presumably this pause lets the lineage data become queryable
    # before the lookup -- confirm against the service behaviour.
    time.sleep(SLEEP_TIME_ONE_SECONDS)

    artifacts = endpoint_context_associate_with_model.models_v2()
    assert len(artifacts) == 1

    # Exactly one element is guaranteed by the assertion above.
    found = artifacts[0]
    assert found.artifact_arn == model_obj.artifact_arn
    assert found.artifact_name == model_obj.artifact_name
    assert found.artifact_type == "Model"
    assert found.properties == model_obj.properties


def test_dataset_artifacts(static_endpoint_context):
artifacts_from_query = static_endpoint_context.dataset_artifacts()

Expand Down
112 changes: 112 additions & 0 deletions tests/unit/sagemaker/lineage/test_endpoint_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import unittest.mock

from sagemaker.lineage import context, _api_types
from sagemaker.lineage._api_types import ArtifactSource
from sagemaker.lineage.artifact import ModelArtifact


def test_models(sagemaker_session):
Expand Down Expand Up @@ -75,3 +77,113 @@ def test_models(sagemaker_session):
)
]
assert expected_model_list == model_list


def test_models_v2(sagemaker_session):
    """Verify ``models_v2`` runs two lineage queries and loads the model artifact.

    The mocked ``query_lineage`` returns the same single vertex for every call, so
    both the model-deployment query and the follow-up model query start from the
    same ARN. The artifact description is served by the mocked ``describe_artifact``.
    """
    arn = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-integ-3b05f017-0d87-4c37"
    endpoint_uri = "arn:aws:sagemaker:us-west-2:0123456789012:endpoint/myendpoint"
    model_uri = "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel"
    source_id = "Thu Dec 17 17:16:24 UTC 2020"
    execution_arn = (
        "arn:aws:sagemaker:us-west-2:0123456789012:"
        "pipeline/mypipeline/execution/0irnteql64d0"
    )
    timestamp = 1608225384.0

    obj = context.EndpointContext(
        sagemaker_session,
        context_name="foo",
        context_arn=arn,
    )
    client = sagemaker_session.sagemaker_client

    # Every lineage query answers with one model artifact vertex.
    client.query_lineage.return_value = {
        "Vertices": [{"Arn": arn, "Type": "Model", "LineageType": "Artifact"}],
        "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}],
    }

    client.describe_context.return_value = {
        "ContextName": "MyContext",
        "ContextArn": arn,
        "Source": {
            "SourceUri": endpoint_uri,
            "SourceType": "ARN",
            "SourceId": source_id,
        },
        "ContextType": "Endpoint",
        "Properties": {
            "PipelineExecutionArn": execution_arn,
            "PipelineStepName": "MyStep",
            "Status": "Completed",
        },
        "CreationTime": timestamp,
        "CreatedBy": {},
        "LastModifiedTime": timestamp,
        "LastModifiedBy": {},
    }

    client.describe_artifact.return_value = {
        "ArtifactName": "MyArtifact",
        "ArtifactArn": arn,
        "Source": {
            "SourceUri": model_uri,
            "SourceType": "ARN",
            "SourceId": source_id,
        },
        "ArtifactType": "Model",
        "Properties": {
            "PipelineExecutionArn": execution_arn,
            "PipelineStepName": "MyStep",
            "Status": "Completed",
        },
        "CreationTime": timestamp,
        "CreatedBy": {},
        "LastModifiedTime": timestamp,
        "LastModifiedBy": {},
    }

    model_list = obj.models_v2()

    def lineage_call(types, lineage_types):
        """Build the expected ``query_lineage`` call for the given filter values."""
        return unittest.mock.call(
            Direction="Descendants",
            Filters={"Types": types, "LineageTypes": lineage_types},
            IncludeEdges=False,
            MaxDepth=10,
            StartArns=[arn],
        )

    expected_calls = [
        lineage_call(["ModelDeployment"], ["Action"]),
        lineage_call(["Model"], ["Artifact"]),
    ]
    assert expected_calls == client.query_lineage.mock_calls

    expected = ModelArtifact(
        artifact_arn=arn,
        artifact_name="MyArtifact",
        source=ArtifactSource(
            source_uri=model_uri,
            source_types=None,
            source_type="ARN",
            source_id=source_id,
        ),
        artifact_type="Model",
        properties={
            "PipelineExecutionArn": execution_arn,
            "PipelineStepName": "MyStep",
            "Status": "Completed",
        },
        creation_time=timestamp,
        created_by={},
        last_modified_time=timestamp,
        last_modified_by={},
    )
    assert expected.artifact_type == "Model"

    # Compare field by field against the first returned artifact.
    actual = model_list[0]
    compared_fields = (
        "artifact_arn",
        "artifact_name",
        "source",
        "artifact_type",
        "properties",
        "creation_time",
        "created_by",
        "last_modified_time",
        "last_modified_by",
    )
    for field in compared_fields:
        assert getattr(expected, field) == getattr(actual, field)