Skip to content

Commit 7ce1d15

Browse files
authored
feature: Adds Lineage queries in artifact, context and trial components (#2838)
1 parent b82fb8a commit 7ce1d15

18 files changed

+1400
-29
lines changed

src/sagemaker/lineage/action.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
from typing import Optional, Iterator, List
1717
from datetime import datetime
1818

19-
from sagemaker import Session
19+
from sagemaker.session import Session
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
2323
from sagemaker.lineage.artifact import Artifact
24-
from sagemaker.lineage.context import Context
2524

2625
from sagemaker.lineage.query import (
2726
LineageQuery,
@@ -126,7 +125,7 @@ def delete(self, disassociate: bool = False):
126125
self._invoke_api(self._boto_delete_method, self._boto_delete_members)
127126

128127
@classmethod
129-
def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action":
128+
def load(cls, action_name: str, sagemaker_session=None) -> "Action":
130129
"""Load an existing action and return an ``Action`` object representing it.
131130
132131
Args:
@@ -324,7 +323,7 @@ def model_package(self):
324323

325324
def endpoints(
326325
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
327-
) -> List[Context]:
326+
) -> List:
328327
"""Use a lineage query to retrieve downstream endpoint contexts that use this action.
329328
330329
Args:

src/sagemaker/lineage/artifact.py

+129-4
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
143143
return artifact
144144

145145
def downstream_trials(self, sagemaker_session=None) -> list:
146-
"""Retrieve all trial runs which that use this artifact.
146+
"""Use the lineage API to retrieve all downstream trials that use this artifact.
147147
148148
Args:
149-
sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149+
sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150150
will be created.
151151
152152
Returns:
@@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159159
)
160160
trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations))
161161

162+
return self._get_trial_from_trial_component(trial_component_arns)
163+
164+
def downstream_trials_v2(self) -> list:
    """Run a lineage query for the trials found downstream of this artifact.

    Thin wrapper around ``_trials`` that fixes the query direction to
    ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    descendants = LineageQueryDirectionEnum.DESCENDANTS
    return self._trials(direction=descendants)
171+
172+
def upstream_trials(self) -> List:
    """Run a lineage query for the trials found upstream of this artifact.

    Thin wrapper around ``_trials`` that fixes the query direction to
    ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    ascendants = LineageQueryDirectionEnum.ASCENDANTS
    return self._trials(direction=ascendants)
179+
180+
def _trials(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List:
    """Resolve the trials connected to this artifact through a lineage query.

    The query starts at ``self.artifact_arn``, keeps only trial-component
    vertices, and hands their ARNs to ``_get_trial_from_trial_component``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    only_trial_components = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=only_trial_components,
        direction=direction,
        include_edges=False,
    )
    component_arns: list = [vertex.arn for vertex in result.vertices]
    return self._get_trial_from_trial_component(component_arns)
200+
201+
def _get_trial_from_trial_component(self, trial_component_arns: list) -> List:
202+
"""Retrieve all upstream trial runs which that use the trial component arns.
203+
204+
Args:
205+
trial_component_arns (list): list of trial component arns
206+
207+
Returns:
208+
[Trial]: A list of SageMaker `Trial` objects.
209+
"""
162210
if not trial_component_arns:
163211
# no outgoing associations for this artifact
164212
return []
@@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170218
num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn)
171219
trial_components: list = []
172220

173-
sagemaker_session = sagemaker_session or _utils.default_session()
221+
sagemaker_session = self.sagemaker_session or _utils.default_session()
174222
sagemaker_client = sagemaker_session.sagemaker_client
175223

176224
for i in range(num_search_batches):
@@ -335,6 +383,17 @@ def list(
335383
sagemaker_session=sagemaker_session,
336384
)
337385

386+
def s3_uri_artifacts(self, s3_uri: str) -> dict:
    """Look up the artifacts whose source URI is the given S3 URI.

    Calls ``list_artifacts`` on this object's SageMaker client with
    ``SourceUri=s3_uri`` and returns the client's response unchanged.

    Args:
        s3_uri (str): A S3 URI.

    Returns:
        dict: The raw ``list_artifacts`` response from the SageMaker client.
    """
    client = self.sagemaker_session.sagemaker_client
    response = client.list_artifacts(SourceUri=s3_uri)
    return response
396+
338397

339398
class ModelArtifact(Artifact):
340399
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +408,7 @@ def endpoints(self) -> list:
349408
"""Get association summaries for endpoints deployed with this model.
350409
351410
Returns:
352-
[AssociationSummary]: A list of associations repesenting the endpoints using the model.
411+
[AssociationSummary]: A list of associations representing the endpoints using the model.
353412
"""
354413
endpoint_development_actions: Iterator = Association.list(
355414
source_arn=self.artifact_arn,
@@ -522,3 +581,69 @@ def endpoint_contexts(
522581
for vertex in query_result.vertices:
523582
endpoint_contexts.append(vertex.to_lineage_object())
524583
return endpoint_contexts
584+
585+
def upstream_datasets(self) -> List[Artifact]:
    """Run a lineage query for the dataset artifacts upstream of this dataset artifact.

    Thin wrapper around ``_datasets`` that fixes the query direction to
    ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    ascendants = LineageQueryDirectionEnum.ASCENDANTS
    return self._datasets(direction=ascendants)
592+
593+
def downstream_datasets(self) -> List[Artifact]:
    """Run a lineage query for the dataset artifacts downstream of this dataset artifact.

    Thin wrapper around ``_datasets`` that fixes the query direction to
    ``LineageQueryDirectionEnum.DESCENDANTS``.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    descendants = LineageQueryDirectionEnum.DESCENDANTS
    return self._datasets(direction=descendants)
600+
601+
def _datasets(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
    """Collect the dataset artifacts linked to this dataset through a lineage query.

    The query starts at ``self.artifact_arn`` and is filtered to artifact
    entities whose source is ``LineageSourceEnum.DATASET``. Each returned
    vertex is converted with ``to_lineage_object()``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.BOTH``.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    datasets = []
    for vertex in result.vertices:
        datasets.append(vertex.to_lineage_object())
    return datasets
622+
623+
624+
class ImageArtifact(Artifact):
    """A SageMaker lineage artifact representing an image.

    Offers lineage traversals for discovering how the image is connected
    to other entities.
    """

    def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
        """Run a lineage query for the dataset artifacts linked to this image artifact.

        The query starts at ``self.artifact_arn`` and is filtered to artifact
        entities whose source is ``LineageSourceEnum.DATASET``. Each returned
        vertex is converted with ``to_lineage_object()``.

        Args:
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        result = lineage_query.query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        datasets = []
        for vertex in result.vertices:
            datasets.append(vertex.to_lineage_object())
        return datasets

src/sagemaker/lineage/context.py

+93-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
LineageQueryDirectionEnum,
3232
)
3333
from sagemaker.lineage.artifact import Artifact
34+
from sagemaker.lineage.action import Action
35+
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
3436

3537

3638
class Context(_base_types.Record):
@@ -256,12 +258,30 @@ def list(
256258
sagemaker_session=sagemaker_session,
257259
)
258260

261+
def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
    """Run a lineage query for the actions linked to this context.

    The query starts at ``self.context_arn`` and keeps only action
    entities. Each returned vertex is converted with
    ``to_lineage_object()``.

    Args:
        direction (LineageQueryDirectionEnum): The query direction.

    Returns:
        list of Actions: Actions.
    """
    only_actions = LineageFilter(entities=[LineageEntityEnum.ACTION])
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=only_actions,
        direction=direction,
        include_edges=False,
    )
    found = []
    for vertex in result.vertices:
        found.append(vertex.to_lineage_object())
    return found
278+
259279

260280
class EndpointContext(Context):
261281
"""An Amazon SageMaker endpoint context, which is part of a SageMaker lineage."""
262282

263283
def models(self) -> List[association.Association]:
264-
"""Get all models deployed by all endpoint versions of the endpoint.
284+
"""Use Lineage API to get all models deployed by this endpoint.
265285
266286
Returns:
267287
list of Associations: Associations that destination represents an endpoint's model.
@@ -286,7 +306,7 @@ def models(self) -> List[association.Association]:
286306
def models_v2(
287307
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
288308
) -> List[Artifact]:
289-
"""Get artifacts representing models from the context lineage by querying lineage data.
309+
"""Use the lineage query to retrieve downstream model artifacts that use this endpoint.
290310
291311
Args:
292312
direction (LineageQueryDirectionEnum, optional): The query direction.
@@ -335,7 +355,7 @@ def models_v2(
335355
def dataset_artifacts(
336356
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
337357
) -> List[Artifact]:
338-
"""Get artifacts representing datasets from the endpoint's lineage.
358+
"""Use the lineage query to retrieve datasets that use this endpoint.
339359
340360
Args:
341361
direction (LineageQueryDirectionEnum, optional): The query direction.
@@ -360,6 +380,9 @@ def training_job_arns(
360380
) -> List[str]:
361381
"""Get ARNs for all training jobs that appear in the endpoint's lineage.
362382
383+
Args:
384+
direction (LineageQueryDirectionEnum, optional): The query direction.
385+
363386
Returns:
364387
list of str: Training job ARNs.
365388
"""
@@ -382,11 +405,78 @@ def training_job_arns(
382405
training_job_arns.append(trial_component["Source"]["SourceArn"])
383406
return training_job_arns
384407

408+
def processing_jobs(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Run a lineage query for the processing jobs in this context's lineage.

    The query starts at ``self.context_arn`` and is filtered to trial
    components whose source is ``LineageSourceEnum.PROCESSING_JOB``. Each
    returned vertex is converted with ``to_lineage_object()``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: Lineage trial components that represent
        processing jobs.
    """
    processing_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.PROCESSING_JOB]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=processing_filter,
        direction=direction,
        include_edges=False,
    )
    jobs = []
    for vertex in result.vertices:
        jobs.append(vertex.to_lineage_object())
    return jobs
429+
430+
def transform_jobs(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Run a lineage query for the transform jobs in this context's lineage.

    The query starts at ``self.context_arn`` and is filtered to trial
    components whose source is ``LineageSourceEnum.TRANSFORM_JOB``. Each
    returned vertex is converted with ``to_lineage_object()``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: Lineage trial components that represent
        transform jobs.
    """
    transform_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRANSFORM_JOB]
    )
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=transform_filter,
        direction=direction,
        include_edges=False,
    )
    jobs = []
    for vertex in result.vertices:
        jobs.append(vertex.to_lineage_object())
    return jobs
451+
452+
def trial_components(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Run a lineage query for the trial components in this context's lineage.

    The query starts at ``self.context_arn`` and keeps only trial-component
    entities, with no source filter. Each returned vertex is converted with
    ``to_lineage_object()``.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.
            Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

    Returns:
        list of LineageTrialComponent: Lineage trial components.
    """
    only_trial_components = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage_query = LineageQuery(self.sagemaker_session)
    result = lineage_query.query(
        start_arns=[self.context_arn],
        query_filter=only_trial_components,
        direction=direction,
        include_edges=False,
    )
    components = []
    for vertex in result.vertices:
        components.append(vertex.to_lineage_object())
    return components
471+
385472
def pipeline_execution_arn(
386473
self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
387474
) -> str:
388475
"""Get the ARN for the pipeline execution associated with this endpoint (if any).
389476
477+
Args:
478+
direction (LineageQueryDirectionEnum, optional): The query direction.
479+
390480
Returns:
391481
str: A pipeline execution ARN.
392482
"""

0 commit comments

Comments
 (0)