@@ -143,10 +143,10 @@ def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact":
143
143
return artifact
144
144
145
145
def downstream_trials (self , sagemaker_session = None ) -> list :
146
- """Retrieve all trial runs which that use this artifact.
146
+ """Use the lineage API to retrieve all downstream trials that use this artifact.
147
147
148
148
Args:
149
- sagemaker_session (obj): Sagemaker Sesssion to use. If not provided a default session
149
+ sagemaker_session (obj): Sagemaker Session to use. If not provided a default session
150
150
will be created.
151
151
152
152
Returns:
@@ -159,6 +159,54 @@ def downstream_trials(self, sagemaker_session=None) -> list:
159
159
)
160
160
trial_component_arns : list = list (map (lambda x : x .destination_arn , outgoing_associations ))
161
161
162
+ return self ._get_trial_from_trial_component (trial_component_arns )
163
+
164
+ def downstream_trials_v2 (self ) -> list :
165
+ """Use a lineage query to retrieve all downstream trials that use this artifact.
166
+
167
+ Returns:
168
+ [Trial]: A list of SageMaker `Trial` objects.
169
+ """
170
+ return self ._trials (direction = LineageQueryDirectionEnum .DESCENDANTS )
171
+
172
+ def upstream_trials (self ) -> List :
173
+ """Use the lineage query to retrieve all upstream trials that use this artifact.
174
+
175
+ Returns:
176
+ [Trial]: A list of SageMaker `Trial` objects.
177
+ """
178
+ return self ._trials (direction = LineageQueryDirectionEnum .ASCENDANTS )
179
+
180
+ def _trials (
181
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
182
+ ) -> List :
183
+ """Use the lineage query to retrieve all trials that use this artifact.
184
+
185
+ Args:
186
+ direction (LineageQueryDirectionEnum, optional): The query direction.
187
+
188
+ Returns:
189
+ [Trial]: A list of SageMaker `Trial` objects.
190
+ """
191
+ query_filter = LineageFilter (entities = [LineageEntityEnum .TRIAL_COMPONENT ])
192
+ query_result = LineageQuery (self .sagemaker_session ).query (
193
+ start_arns = [self .artifact_arn ],
194
+ query_filter = query_filter ,
195
+ direction = direction ,
196
+ include_edges = False ,
197
+ )
198
+ trial_component_arns : list = list (map (lambda x : x .arn , query_result .vertices ))
199
+ return self ._get_trial_from_trial_component (trial_component_arns )
200
+
201
+ def _get_trial_from_trial_component (self , trial_component_arns : list ) -> List :
202
+ """Retrieve all upstream trial runs which that use the trial component arns.
203
+
204
+ Args:
205
+ trial_component_arns (list): list of trial component arns
206
+
207
+ Returns:
208
+ [Trial]: A list of SageMaker `Trial` objects.
209
+ """
162
210
if not trial_component_arns :
163
211
# no outgoing associations for this artifact
164
212
return []
@@ -170,7 +218,7 @@ def downstream_trials(self, sagemaker_session=None) -> list:
170
218
num_search_batches = math .ceil (len (trial_component_arns ) % max_search_by_arn )
171
219
trial_components : list = []
172
220
173
- sagemaker_session = sagemaker_session or _utils .default_session ()
221
+ sagemaker_session = self . sagemaker_session or _utils .default_session ()
174
222
sagemaker_client = sagemaker_session .sagemaker_client
175
223
176
224
for i in range (num_search_batches ):
@@ -335,6 +383,17 @@ def list(
335
383
sagemaker_session = sagemaker_session ,
336
384
)
337
385
386
+ def s3_uri_artifacts (self , s3_uri : str ) -> dict :
387
+ """Retrieve a list of artifacts that use provided s3 uri.
388
+
389
+ Args:
390
+ s3_uri (str): A S3 URI.
391
+
392
+ Returns:
393
+ A list of ``Artifacts``
394
+ """
395
+ return self .sagemaker_session .sagemaker_client .list_artifacts (SourceUri = s3_uri )
396
+
338
397
339
398
class ModelArtifact (Artifact ):
340
399
"""A SageMaker lineage artifact representing a model.
@@ -349,7 +408,7 @@ def endpoints(self) -> list:
349
408
"""Get association summaries for endpoints deployed with this model.
350
409
351
410
Returns:
352
- [AssociationSummary]: A list of associations repesenting the endpoints using the model.
411
+ [AssociationSummary]: A list of associations representing the endpoints using the model.
353
412
"""
354
413
endpoint_development_actions : Iterator = Association .list (
355
414
source_arn = self .artifact_arn ,
@@ -522,3 +581,69 @@ def endpoint_contexts(
522
581
for vertex in query_result .vertices :
523
582
endpoint_contexts .append (vertex .to_lineage_object ())
524
583
return endpoint_contexts
584
+
585
+ def upstream_datasets (self ) -> List [Artifact ]:
586
+ """Use the lineage query to retrieve upstream artifacts that use this dataset artifact.
587
+
588
+ Returns:
589
+ list of Artifacts: Artifacts representing an dataset.
590
+ """
591
+ return self ._datasets (direction = LineageQueryDirectionEnum .ASCENDANTS )
592
+
593
+ def downstream_datasets (self ) -> List [Artifact ]:
594
+ """Use the lineage query to retrieve downstream artifacts that use this dataset.
595
+
596
+ Returns:
597
+ list of Artifacts: Artifacts representing an dataset.
598
+ """
599
+ return self ._datasets (direction = LineageQueryDirectionEnum .DESCENDANTS )
600
+
601
+ def _datasets (
602
+ self , direction : LineageQueryDirectionEnum = LineageQueryDirectionEnum .BOTH
603
+ ) -> List [Artifact ]:
604
+ """Use the lineage query to retrieve all artifacts that use this dataset.
605
+
606
+ Args:
607
+ direction (LineageQueryDirectionEnum, optional): The query direction.
608
+
609
+ Returns:
610
+ list of Artifacts: Artifacts representing an dataset.
611
+ """
612
+ query_filter = LineageFilter (
613
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
614
+ )
615
+ query_result = LineageQuery (self .sagemaker_session ).query (
616
+ start_arns = [self .artifact_arn ],
617
+ query_filter = query_filter ,
618
+ direction = direction ,
619
+ include_edges = False ,
620
+ )
621
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
622
+
623
+
624
+ class ImageArtifact (Artifact ):
625
+ """A SageMaker lineage artifact representing an image.
626
+
627
+ Common model specific lineage traversals to discover how the image is connected
628
+ to other entities.
629
+ """
630
+
631
+ def datasets (self , direction : LineageQueryDirectionEnum ) -> List [Artifact ]:
632
+ """Use the lineage query to retrieve datasets that use this image artifact.
633
+
634
+ Args:
635
+ direction (LineageQueryDirectionEnum): The query direction.
636
+
637
+ Returns:
638
+ list of Artifacts: Artifacts representing a dataset.
639
+ """
640
+ query_filter = LineageFilter (
641
+ entities = [LineageEntityEnum .ARTIFACT ], sources = [LineageSourceEnum .DATASET ]
642
+ )
643
+ query_result = LineageQuery (self .sagemaker_session ).query (
644
+ start_arns = [self .artifact_arn ],
645
+ query_filter = query_filter ,
646
+ direction = direction ,
647
+ include_edges = False ,
648
+ )
649
+ return [vertex .to_lineage_object () for vertex in query_result .vertices ]
0 commit comments