|
15 | 15 | import unittest.mock
|
16 | 16 |
|
17 | 17 | from sagemaker.lineage import context, _api_types
|
| 18 | +from sagemaker.lineage._api_types import ArtifactSource |
| 19 | +from sagemaker.lineage.artifact import ModelArtifact |
18 | 20 |
|
19 | 21 |
|
20 | 22 | def test_models(sagemaker_session):
|
@@ -75,3 +77,113 @@ def test_models(sagemaker_session):
|
75 | 77 | )
|
76 | 78 | ]
|
77 | 79 | assert expected_model_list == model_list
|
| 80 | + |
| 81 | + |
| 82 | +def test_models_v2(sagemaker_session): |
| 83 | + arn1 = "arn:aws:sagemaker:us-west-2:123456789012:context/lineage-integ-3b05f017-0d87-4c37" |
| 84 | + |
| 85 | + obj = context.EndpointContext(sagemaker_session, context_name="foo", context_arn=arn1) |
| 86 | + |
| 87 | + sagemaker_session.sagemaker_client.query_lineage.return_value = { |
| 88 | + "Vertices": [ |
| 89 | + {"Arn": arn1, "Type": "Model", "LineageType": "Artifact"}, |
| 90 | + ], |
| 91 | + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], |
| 92 | + } |
| 93 | + |
| 94 | + sagemaker_session.sagemaker_client.describe_context.return_value = { |
| 95 | + "ContextName": "MyContext", |
| 96 | + "ContextArn": arn1, |
| 97 | + "Source": { |
| 98 | + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:endpoint/myendpoint", |
| 99 | + "SourceType": "ARN", |
| 100 | + "SourceId": "Thu Dec 17 17:16:24 UTC 2020", |
| 101 | + }, |
| 102 | + "ContextType": "Endpoint", |
| 103 | + "Properties": { |
| 104 | + "PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\ |
| 105 | + pipeline/mypipeline/execution/0irnteql64d0", |
| 106 | + "PipelineStepName": "MyStep", |
| 107 | + "Status": "Completed", |
| 108 | + }, |
| 109 | + "CreationTime": 1608225384.0, |
| 110 | + "CreatedBy": {}, |
| 111 | + "LastModifiedTime": 1608225384.0, |
| 112 | + "LastModifiedBy": {}, |
| 113 | + } |
| 114 | + |
| 115 | + sagemaker_session.sagemaker_client.describe_artifact.return_value = { |
| 116 | + "ArtifactName": "MyArtifact", |
| 117 | + "ArtifactArn": arn1, |
| 118 | + "Source": { |
| 119 | + "SourceUri": "arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel", |
| 120 | + "SourceType": "ARN", |
| 121 | + "SourceId": "Thu Dec 17 17:16:24 UTC 2020", |
| 122 | + }, |
| 123 | + "ArtifactType": "Model", |
| 124 | + "Properties": { |
| 125 | + "PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\ |
| 126 | + pipeline/mypipeline/execution/0irnteql64d0", |
| 127 | + "PipelineStepName": "MyStep", |
| 128 | + "Status": "Completed", |
| 129 | + }, |
| 130 | + "CreationTime": 1608225384.0, |
| 131 | + "CreatedBy": {}, |
| 132 | + "LastModifiedTime": 1608225384.0, |
| 133 | + "LastModifiedBy": {}, |
| 134 | + } |
| 135 | + |
| 136 | + model_list = obj.models_v2() |
| 137 | + |
| 138 | + expected_calls = [ |
| 139 | + unittest.mock.call( |
| 140 | + Direction="Descendants", |
| 141 | + Filters={"Types": ["ModelDeployment"], "LineageTypes": ["Action"]}, |
| 142 | + IncludeEdges=False, |
| 143 | + MaxDepth=10, |
| 144 | + StartArns=[arn1], |
| 145 | + ), |
| 146 | + unittest.mock.call( |
| 147 | + Direction="Descendants", |
| 148 | + Filters={"Types": ["Model"], "LineageTypes": ["Artifact"]}, |
| 149 | + IncludeEdges=False, |
| 150 | + MaxDepth=10, |
| 151 | + StartArns=[arn1], |
| 152 | + ), |
| 153 | + ] |
| 154 | + assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls |
| 155 | + |
| 156 | + expected_model_list = [ |
| 157 | + ModelArtifact( |
| 158 | + artifact_arn=arn1, |
| 159 | + artifact_name="MyArtifact", |
| 160 | + source=ArtifactSource( |
| 161 | + source_uri="arn:aws:sagemaker:us-west-2:0123456789012:model/mymodel", |
| 162 | + source_types=None, |
| 163 | + source_type="ARN", |
| 164 | + source_id="Thu Dec 17 17:16:24 UTC 2020", |
| 165 | + ), |
| 166 | + artifact_type="Model", |
| 167 | + properties={ |
| 168 | + "PipelineExecutionArn": "arn:aws:sagemaker:us-west-2:0123456789012:\ |
| 169 | + pipeline/mypipeline/execution/0irnteql64d0", |
| 170 | + "PipelineStepName": "MyStep", |
| 171 | + "Status": "Completed", |
| 172 | + }, |
| 173 | + creation_time=1608225384.0, |
| 174 | + created_by={}, |
| 175 | + last_modified_time=1608225384.0, |
| 176 | + last_modified_by={}, |
| 177 | + ) |
| 178 | + ] |
| 179 | + |
| 180 | + assert expected_model_list[0].artifact_arn == model_list[0].artifact_arn |
| 181 | + assert expected_model_list[0].artifact_name == model_list[0].artifact_name |
| 182 | + assert expected_model_list[0].source == model_list[0].source |
| 183 | + assert expected_model_list[0].artifact_type == model_list[0].artifact_type |
| 184 | + assert expected_model_list[0].artifact_type == "Model" |
| 185 | + assert expected_model_list[0].properties == model_list[0].properties |
| 186 | + assert expected_model_list[0].creation_time == model_list[0].creation_time |
| 187 | + assert expected_model_list[0].created_by == model_list[0].created_by |
| 188 | + assert expected_model_list[0].last_modified_time == model_list[0].last_modified_time |
| 189 | + assert expected_model_list[0].last_modified_by == model_list[0].last_modified_by |
0 commit comments