|
13 | 13 | """This module contains code to create and manage SageMaker ``Actions``."""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 |
| -from typing import Optional, Iterator |
| 16 | +from typing import Optional, Iterator, List |
17 | 17 | from datetime import datetime
|
18 | 18 |
|
19 | 19 | from sagemaker import Session
|
20 | 20 | from sagemaker.apiutils import _base_types
|
21 | 21 | from sagemaker.lineage import _api_types, _utils
|
22 | 22 | from sagemaker.lineage._api_types import ActionSource, ActionSummary
|
| 23 | +from sagemaker.lineage.artifact import Artifact |
| 24 | + |
| 25 | +from sagemaker.lineage.query import ( |
| 26 | + LineageQuery, |
| 27 | + LineageFilter, |
| 28 | + LineageSourceEnum, |
| 29 | + LineageEntityEnum, |
| 30 | + LineageQueryDirectionEnum, |
| 31 | +) |
23 | 32 |
|
24 | 33 |
|
25 | 34 | class Action(_base_types.Record):
|
@@ -250,3 +259,86 @@ def list(
|
250 | 259 | max_results=max_results,
|
251 | 260 | next_token=next_token,
|
252 | 261 | )
|
| 262 | + |
| 263 | + def artifacts( |
| 264 | + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH |
| 265 | + ) -> List[Artifact]: |
| 266 | + """Use a lineage query to retrieve all artifacts that use this action. |
| 267 | +
|
| 268 | + Args: |
| 269 | + direction (LineageQueryDirectionEnum, optional): The query direction. |
| 270 | +
|
| 271 | + Returns: |
| 272 | + list of Artifacts: Artifacts. |
| 273 | + """ |
| 274 | + query_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT]) |
| 275 | + query_result = LineageQuery(self.sagemaker_session).query( |
| 276 | + start_arns=[self.action_arn], |
| 277 | + query_filter=query_filter, |
| 278 | + direction=direction, |
| 279 | + include_edges=False, |
| 280 | + ) |
| 281 | + return [vertex.to_lineage_object() for vertex in query_result.vertices] |
| 282 | + |
| 283 | + |
| 284 | +class ModelPackageApprovalAction(Action): |
| 285 | + """An Amazon SageMaker model package approval action, which is part of a SageMaker lineage.""" |
| 286 | + |
| 287 | + def datasets( |
| 288 | + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS |
| 289 | + ) -> List[Artifact]: |
| 290 | + """Use a lineage query to retrieve all upstream datasets that use this action. |
| 291 | +
|
| 292 | + Args: |
| 293 | + direction (LineageQueryDirectionEnum, optional): The query direction. |
| 294 | +
|
| 295 | + Returns: |
| 296 | + list of Artifacts: Artifacts representing a dataset. |
| 297 | + """ |
| 298 | + query_filter = LineageFilter( |
| 299 | + entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET] |
| 300 | + ) |
| 301 | + query_result = LineageQuery(self.sagemaker_session).query( |
| 302 | + start_arns=[self.action_arn], |
| 303 | + query_filter=query_filter, |
| 304 | + direction=direction, |
| 305 | + include_edges=False, |
| 306 | + ) |
| 307 | + return [vertex.to_lineage_object() for vertex in query_result.vertices] |
| 308 | + |
| 309 | + def model_package(self): |
| 310 | + """Get model package from model package approval action. |
| 311 | +
|
| 312 | + Returns: |
| 313 | + Model package. |
| 314 | + """ |
| 315 | + source_uri = self.source.source_uri |
| 316 | + if source_uri is None: |
| 317 | + return None |
| 318 | + |
| 319 | + model_package_name = source_uri.split("/")[1] |
| 320 | + return self.sagemaker_session.sagemaker_client.describe_model_package( |
| 321 | + ModelPackageName=model_package_name |
| 322 | + ) |
| 323 | + |
| 324 | + def endpoints( |
| 325 | + self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS |
| 326 | + ): |
| 327 | + """Use a lineage query to retrieve downstream endpoint contexts that use this action. |
| 328 | +
|
| 329 | + Args: |
| 330 | + direction (LineageQueryDirectionEnum, optional): The query direction. |
| 331 | +
|
| 332 | + Returns: |
| 333 | + list of Contexts: Contexts representing an endpoint. |
| 334 | + """ |
| 335 | + query_filter = LineageFilter( |
| 336 | + entities=[LineageEntityEnum.CONTEXT], sources=[LineageSourceEnum.ENDPOINT] |
| 337 | + ) |
| 338 | + query_result = LineageQuery(self.sagemaker_session).query( |
| 339 | + start_arns=[self.action_arn], |
| 340 | + query_filter=query_filter, |
| 341 | + direction=direction, |
| 342 | + include_edges=False, |
| 343 | + ) |
| 344 | + return [vertex.to_lineage_object() for vertex in query_result.vertices] |
0 commit comments