1
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
1
+ # Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License"). You
4
4
# may not use this file except in compliance with the License. A copy of
14
14
15
15
import unittest .mock
16
16
17
- import pytest
18
- from sagemaker .lineage import visualizer
19
17
import pandas as pd
20
18
from collections import OrderedDict
21
19
22
20
23
- @pytest .fixture
24
- def sagemaker_session ():
25
- return unittest .mock .Mock ()
26
-
27
-
28
- @pytest .fixture
29
- def vizualizer (sagemaker_session ):
30
- return visualizer .LineageTableVisualizer (sagemaker_session )
31
-
32
-
33
def test_friendly_name_short_uri(viz, sagemaker_session):
    """A short artifact source URI is expected back verbatim as the friendly name.

    Args:
        viz: visualizer fixture under test (not defined in this file —
            presumably provided by a shared conftest; verify).
        sagemaker_session: session fixture whose ``describe_artifact``
            response is stubbed here.
    """
    short_uri = "s3://f-069083975568/train.txt"
    artifact_description = {"Source": {"SourceUri": short_uri, "SourceTypes": ""}}
    sagemaker_session.sagemaker_client.describe_artifact.return_value = artifact_description

    # No explicit name is supplied, so the result must come from the stubbed artifact.
    friendly_name = viz._get_friendly_name(name=None, arn="test_arn", entity_type="artifact")

    assert short_uri == friendly_name
41
29
42
30
43
- def test_friendly_name_long_uri (vizualizer , sagemaker_session ):
31
+ def test_friendly_name_long_uri (viz , sagemaker_session ):
44
32
uri = (
45
33
"s3://flintstone-end-to-end-tests-gamma-us-west-2-069083975568/results/canary-auto-1608761252626/"
46
34
"preprocessed-data/tuning_data/train.txt"
@@ -49,12 +37,12 @@ def test_friendly_name_long_uri(vizualizer, sagemaker_session):
49
37
sagemaker_session .sagemaker_client .describe_artifact .return_value = {
50
38
"Source" : {"SourceUri" : uri , "SourceTypes" : "" }
51
39
}
52
- actual_name = vizualizer ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
40
+ actual_name = viz ._get_friendly_name (name = None , arn = arn , entity_type = "artifact" )
53
41
expected_name = "s3://.../preprocessed-data/tuning_data/train.txt"
54
42
assert expected_name == actual_name
55
43
56
44
57
- def test_trial_component_name (sagemaker_session , vizualizer ):
45
+ def test_trial_component_name (viz , sagemaker_session ):
58
46
name = "tc-name"
59
47
60
48
sagemaker_session .sagemaker_client .describe_trial_component .return_value = {
@@ -90,7 +78,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
90
78
},
91
79
]
92
80
93
- df = vizualizer .show (trial_component_name = name )
81
+ df = viz .show (trial_component_name = name )
94
82
95
83
sagemaker_session .sagemaker_client .describe_trial_component .assert_called_with (
96
84
TrialComponentName = name ,
@@ -121,7 +109,7 @@ def test_trial_component_name(sagemaker_session, vizualizer):
121
109
pd .testing .assert_frame_equal (expected_dataframe , df )
122
110
123
111
124
- def test_model_package_arn (sagemaker_session , vizualizer ):
112
+ def test_model_package_arn (viz , sagemaker_session ):
125
113
name = "model_package_arn"
126
114
127
115
sagemaker_session .sagemaker_client .list_artifacts .return_value = {
@@ -157,7 +145,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
157
145
},
158
146
]
159
147
160
- df = vizualizer .show (model_package_arn = name )
148
+ df = viz .show (model_package_arn = name )
161
149
162
150
sagemaker_session .sagemaker_client .list_artifacts .assert_called_with (
163
151
SourceUri = name ,
@@ -188,7 +176,7 @@ def test_model_package_arn(sagemaker_session, vizualizer):
188
176
pd .testing .assert_frame_equal (expected_dataframe , df )
189
177
190
178
191
- def test_endpoint_arn (sagemaker_session , vizualizer ):
179
+ def test_endpoint_arn (viz , sagemaker_session ):
192
180
name = "endpoint_arn"
193
181
194
182
sagemaker_session .sagemaker_client .list_contexts .return_value = {
@@ -224,7 +212,7 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
224
212
},
225
213
]
226
214
227
- df = vizualizer .show (endpoint_arn = name )
215
+ df = viz .show (endpoint_arn = name )
228
216
229
217
sagemaker_session .sagemaker_client .list_contexts .assert_called_with (
230
218
SourceUri = name ,
@@ -253,3 +241,71 @@ def test_endpoint_arn(sagemaker_session, vizualizer):
253
241
)
254
242
255
243
pd .testing .assert_frame_equal (expected_dataframe , df )
244
+
245
+
246
def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
    """Check the lineage table produced for a processing-job pipeline execution step.

    The test expects the step's processing job ARN to be used as the source
    when listing trial components, and associations to be listed twice: first
    with the trial component as destination, then with it as source.

    Args:
        viz: visualizer fixture under test (not defined in this file —
            presumably provided by a shared conftest; verify).
        sagemaker_session: session fixture whose ``sagemaker_client`` calls
            are stubbed and inspected here.
    """
    client = sagemaker_session.sagemaker_client
    mock_call = unittest.mock.call

    client.list_trial_components.return_value = {
        "TrialComponentSummaries": [{"TrialComponentArn": "tc-arn"}]
    }

    # ``side_effect`` hands these out one per call, in call order.
    first_response = {
        "AssociationSummaries": [
            {
                "SourceArn": "a:b:c:d:e:artifact/src-arn-1",
                "SourceName": "source-name-1",
                "SourceType": "source-type-1",
                "DestinationArn": "a:b:c:d:e:artifact/dest-arn-1",
                "DestinationName": "dest-name-1",
                "DestinationType": "dest-type-1",
                "AssociationType": "type-1",
            }
        ]
    }
    second_response = {
        "AssociationSummaries": [
            {
                "SourceArn": "a:b:c:d:e:artifact/src-arn-2",
                "SourceName": "source-name-2",
                "SourceType": "source-type-2",
                "DestinationArn": "a:b:c:d:e:artifact/dest-arn-2",
                "DestinationName": "dest-name-2",
                "DestinationType": "dest-type-2",
                "AssociationType": "type-2",
            }
        ]
    }
    client.list_associations.side_effect = [first_response, second_response]

    pipeline_step = {"Metadata": {"ProcessingJob": {"Arn": "proc-job-arn"}}}
    df = viz.show(pipeline_execution_step=pipeline_step)

    client.list_trial_components.assert_called_with(SourceArn="proc-job-arn")

    expected_calls = [
        mock_call(DestinationArn="tc-arn"),
        mock_call(SourceArn="tc-arn"),
    ]
    assert expected_calls == client.list_associations.mock_calls

    # Row one comes from the source side of the first response, row two from
    # the destination side of the second response.
    expected_columns = OrderedDict(
        [
            ("Name/Source", ["source-name-1", "dest-name-2"]),
            ("Direction", ["Input", "Output"]),
            ("Type", ["source-type-1", "dest-type-2"]),
            ("Association Type", ["type-1", "type-2"]),
            ("Lineage Type", ["artifact", "artifact"]),
        ]
    )
    expected_dataframe = pd.DataFrame.from_dict(expected_columns)

    pd.testing.assert_frame_equal(expected_dataframe, df)
0 commit comments