diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index f2d1bf8c14..a1ab295b05 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -83,10 +83,11 @@ def __init__( self._session = sagemaker_session def to_lineage_object(self): - """Convert the ``Vertex`` object to its corresponding ``Artifact`` or ``Context`` object.""" + """Convert the ``Vertex`` object to its corresponding Artifact, Action, Context object.""" from sagemaker.lineage.artifact import Artifact, ModelArtifact from sagemaker.lineage.context import Context, EndpointContext from sagemaker.lineage.artifact import DatasetArtifact + from sagemaker.lineage.action import Action if self.lineage_entity == LineageEntityEnum.CONTEXT.value: resource_name = get_resource_name_from_arn(self.arn) @@ -103,6 +104,9 @@ def to_lineage_object(self): return DatasetArtifact.load(artifact_arn=self.arn, sagemaker_session=self._session) return Artifact.load(artifact_arn=self.arn, sagemaker_session=self._session) + if self.lineage_entity == LineageEntityEnum.ACTION.value: + return Action.load(action_name=self.arn.split("/")[1], sagemaker_session=self._session) + raise ValueError("Vertex cannot be converted to a lineage object.") diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 17d3eabe92..c25ca6f38f 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -13,6 +13,7 @@ from __future__ import absolute_import from sagemaker.lineage.artifact import DatasetArtifact, ModelArtifact, Artifact from sagemaker.lineage.context import EndpointContext, Context +from sagemaker.lineage.action import Action from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest @@ -240,10 +241,38 @@ def test_vertex_to_object_artifact(sagemaker_session): assert isinstance(artifact, Artifact) +def test_vertex_to_object_action(sagemaker_session): + vertex = Vertex( + arn="arn:aws:sagemaker:us-west-2:0123456789012:action/cp-m5-20210424t041405868z-1619237657-1-aws-endpoint", + lineage_entity=LineageEntityEnum.ACTION.value, + lineage_source="A", + sagemaker_session=sagemaker_session, + ) + + sagemaker_session.sagemaker_client.describe_action.return_value = { + "ActionName": "cp-m5-20210424t041405868z-1619237657-1-aws-endpoint", + "Source": { + "SourceUri": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3", + "SourceTypes": [], + }, + "ActionType": "A", + "Properties": {}, + "CreationTime": 1608224704.149, + "CreatedBy": {}, + "LastModifiedTime": 1608224704.149, + "LastModifiedBy": {}, + } + + action = vertex.to_lineage_object() + + assert action.action_name == "cp-m5-20210424t041405868z-1619237657-1-aws-endpoint" + assert isinstance(action, Action) + + def test_vertex_to_object_unconvertable(sagemaker_session): vertex = Vertex( arn="arn:aws:sagemaker:us-west-2:0123456789012:artifact/e66eef7f19c05e75284089183491bd4f", - lineage_entity=LineageEntityEnum.ACTION.value, + lineage_entity=LineageEntityEnum.TRIAL_COMPONENT.value, lineage_source=LineageSourceEnum.TENSORBOARD.value, sagemaker_session=sagemaker_session, )