Skip to content

Commit 75afd87

Browse files
authored
feature: Add models_v2 under lineage context (#2800)
1 parent 7efa99e commit 75afd87

File tree

4 files changed

+183
-8
lines changed

4 files changed

+183
-8
lines changed

src/sagemaker/lineage/context.py

+49
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,55 @@ def models(self) -> List[association.Association]:
283283
]
284284
return model_list
285285

286+
def models_v2(
287+
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
288+
) -> List[Artifact]:
289+
"""Get artifacts representing models from the context lineage by querying lineage data.
290+
291+
Args:
292+
direction (LineageQueryDirectionEnum, optional): The query direction.
293+
294+
Returns:
295+
list of Artifacts: Artifacts representing a model.
296+
"""
297+
# Firstly query out the model_deployment vertices
298+
query_filter = LineageFilter(
299+
entities=[LineageEntityEnum.ACTION], sources=[LineageSourceEnum.MODEL_DEPLOYMENT]
300+
)
301+
model_deployment_query_result = LineageQuery(self.sagemaker_session).query(
302+
start_arns=[self.context_arn],
303+
query_filter=query_filter,
304+
direction=direction,
305+
include_edges=False,
306+
)
307+
if not model_deployment_query_result:
308+
return []
309+
310+
model_deployment_vertices: [] = model_deployment_query_result.vertices
311+
312+
# Secondary query model based on model deployment
313+
model_vertices = []
314+
for vertex in model_deployment_vertices:
315+
query_result = LineageQuery(self.sagemaker_session).query(
316+
start_arns=[vertex.arn],
317+
query_filter=LineageFilter(
318+
entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL]
319+
),
320+
direction=LineageQueryDirectionEnum.DESCENDANTS,
321+
include_edges=False,
322+
)
323+
model_vertices.extend(query_result.vertices)
324+
325+
if not model_vertices:
326+
return []
327+
328+
model_artifacts = []
329+
for vertex in model_vertices:
330+
lineage_object = vertex.to_lineage_object()
331+
model_artifacts.append(lineage_object)
332+
333+
return model_artifacts
334+
286335
def dataset_artifacts(
287336
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
288337
) -> List[Artifact]:

tests/integ/sagemaker/lineage/conftest.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from tests.integ.sagemaker.lineage.helpers import name, names
3737

3838
SLEEP_TIME_SECONDS = 1
39+
SLEEP_TIME_TWO_SECONDS = 2
3940
STATIC_PIPELINE_NAME = "SdkIntegTestStaticPipeline17"
4041
STATIC_ENDPOINT_NAME = "SdkIntegTestStaticEndpoint17"
4142

@@ -360,12 +361,10 @@ def endpoint_context_obj(sagemaker_session):
360361

361362
@pytest.fixture
362363
def model_obj(sagemaker_session):
363-
model = context.Context.create(
364-
context_name=name(),
364+
model = artifact.Artifact.create(
365+
artifact_name=name(),
366+
artifact_type="Model",
365367
source_uri="bar1",
366-
source_type="test-source-type1",
367-
context_type="Model",
368-
description="test-description",
369368
properties={"k1": "v1"},
370369
sagemaker_session=sagemaker_session,
371370
)
@@ -417,11 +416,12 @@ def endpoint_context_associate_with_model(sagemaker_session, endpoint_action_obj
417416

418417
association.Association.create(
419418
source_arn=endpoint_action_obj.action_arn,
420-
destination_arn=model_obj.context_arn,
419+
destination_arn=model_obj.artifact_arn,
421420
sagemaker_session=sagemaker_session,
422421
)
423422
yield obj
424-
time.sleep(SLEEP_TIME_SECONDS)
423+
# sleep 2 seconds since take longer for lineage injection
424+
time.sleep(SLEEP_TIME_TWO_SECONDS)
425425
obj.delete(disassociate=True)
426426

427427

tests/integ/sagemaker/lineage/test_endpoint_context.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,31 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code to test SageMaker ``Contexts``"""
1414
from __future__ import absolute_import
15+
import time
16+
17+
SLEEP_TIME_ONE_SECONDS = 1
1518

1619

1720
def test_model(endpoint_context_associate_with_model, model_obj, endpoint_action_obj):
1821
model_list = endpoint_context_associate_with_model.models()
1922
for model in model_list:
2023
assert model.source_arn == endpoint_action_obj.action_arn
21-
assert model.destination_arn == model_obj.context_arn
24+
assert model.destination_arn == model_obj.artifact_arn
2225
assert model.source_type == "ModelDeployment"
2326
assert model.destination_type == "Model"
2427

2528

29+
def test_model_v2(endpoint_context_associate_with_model, model_obj, sagemaker_session):
30+
time.sleep(SLEEP_TIME_ONE_SECONDS)
31+
model_list = endpoint_context_associate_with_model.models_v2()
32+
assert len(model_list) == 1
33+
for model in model_list:
34+
assert model.artifact_arn == model_obj.artifact_arn
35+
assert model.artifact_name == model_obj.artifact_name
36+
assert model.artifact_type == "Model"
37+
assert model.properties == model_obj.properties
38+
39+
2640
def test_dataset_artifacts(static_endpoint_context):
2741
artifacts_from_query = static_endpoint_context.dataset_artifacts()
2842

tests/unit/sagemaker/lineage/test_endpoint_context.py

+112
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import unittest.mock
1616

1717
from sagemaker.lineage import context, _api_types
18+
from sagemaker.lineage._api_types import ArtifactSource
19+
from sagemaker.lineage.artifact import ModelArtifact
1820

1921

2022
def test_models(sagemaker_session):
@@ -75,3 +77,113 @@ def test_models(sagemaker_session):
7577
)
7678
]
7779
assert expected_model_list == model_list
80+
81+
82+
def test_models_v2(sagemaker_session):
83+
arn1 = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-integ-3b05f017-0d87-4c37"
84+
85+
obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=arn1)
86+
87+
sagemaker_session.sagemaker_client.query_lineage.return_value = {
88+
"Vertices": [
89+
{"Arn": arn1, "Type": "Model", "LineageType": "Artifact"},
90+
],
91+
"Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}],
92+
}
93+
94+
sagemaker_session.sagemaker_client.describe_context.return_value = {
95+
"ContextName": "MyContext",
96+
"ContextArn": arn1,
97+
"Source": {
98+
"SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:endpoint/myendpoint",
99+
"SourceType": "ARN",
100+
"SourceId": "Thu Dec 17 17:16:24 UTC 2020",
101+
},
102+
"ContextType": "Endpoint",
103+
"Properties": {
104+
"PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\
105+
pipeline/mypipeline/execution/0irnteql64d0",
106+
"PipelineStepName": "MyStep",
107+
"Status": "Completed",
108+
},
109+
"CreationTime": 1608225384.0,
110+
"CreatedBy": {},
111+
"LastModifiedTime": 1608225384.0,
112+
"LastModifiedBy": {},
113+
}
114+
115+
sagemaker_session.sagemaker_client.describe_artifact.return_value = {
116+
"ArtifactName": "MyArtifact",
117+
"ArtifactArn": arn1,
118+
"Source": {
119+
"SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel",
120+
"SourceType": "ARN",
121+
"SourceId": "Thu Dec 17 17:16:24 UTC 2020",
122+
},
123+
"ArtifactType": "Model",
124+
"Properties": {
125+
"PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\
126+
pipeline/mypipeline/execution/0irnteql64d0",
127+
"PipelineStepName": "MyStep",
128+
"Status": "Completed",
129+
},
130+
"CreationTime": 1608225384.0,
131+
"CreatedBy": {},
132+
"LastModifiedTime": 1608225384.0,
133+
"LastModifiedBy": {},
134+
}
135+
136+
model_list = obj.models_v2()
137+
138+
expected_calls = [
139+
unittest.mock.call(
140+
Direction="Descendants",
141+
Filters={"Types": ["ModelDeployment"], "LineageTypes": ["Action"]},
142+
IncludeEdges=False,
143+
MaxDepth=10,
144+
StartArns=[arn1],
145+
),
146+
unittest.mock.call(
147+
Direction="Descendants",
148+
Filters={"Types": ["Model"], "LineageTypes": ["Artifact"]},
149+
IncludeEdges=False,
150+
MaxDepth=10,
151+
StartArns=[arn1],
152+
),
153+
]
154+
assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls
155+
156+
expected_model_list = [
157+
ModelArtifact(
158+
artifact_arn=arn1,
159+
artifact_name="MyArtifact",
160+
source=ArtifactSource(
161+
source_uri="arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel",
162+
source_types=None,
163+
source_type="ARN",
164+
source_id="Thu Dec 17 17:16:24 UTC 2020",
165+
),
166+
artifact_type="Model",
167+
properties={
168+
"PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\
169+
pipeline/mypipeline/execution/0irnteql64d0",
170+
"PipelineStepName": "MyStep",
171+
"Status": "Completed",
172+
},
173+
creation_time=1608225384.0,
174+
created_by={},
175+
last_modified_time=1608225384.0,
176+
last_modified_by={},
177+
)
178+
]
179+
180+
assert expected_model_list[0].artifact_arn == model_list[0].artifact_arn
181+
assert expected_model_list[0].artifact_name == model_list[0].artifact_name
182+
assert expected_model_list[0].source == model_list[0].source
183+
assert expected_model_list[0].artifact_type == model_list[0].artifact_type
184+
assert expected_model_list[0].artifact_type == "Model"
185+
assert expected_model_list[0].properties == model_list[0].properties
186+
assert expected_model_list[0].creation_time == model_list[0].creation_time
187+
assert expected_model_list[0].created_by == model_list[0].created_by
188+
assert expected_model_list[0].last_modified_time == model_list[0].last_modified_time
189+
assert expected_model_list[0].last_modified_by == model_list[0].last_modified_by

0 commit comments

Comments
 (0)