@@ -49,64 +49,17 @@ def test_trial_component_name(viz, sagemaker_session):
49
49
"TrialComponentArn" : "tc-arn" ,
50
50
}
51
51
52
- sagemaker_session .sagemaker_client .list_associations .side_effect = [
53
- {
54
- "AssociationSummaries" : [
55
- {
56
- "SourceArn" : "a:b:c:d:e:artifact/src-arn-1" ,
57
- "SourceName" : "source-name-1" ,
58
- "SourceType" : "source-type-1" ,
59
- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-1" ,
60
- "DestinationName" : "dest-name-1" ,
61
- "DestinationType" : "dest-type-1" ,
62
- "AssociationType" : "type-1" ,
63
- }
64
- ]
65
- },
66
- {
67
- "AssociationSummaries" : [
68
- {
69
- "SourceArn" : "a:b:c:d:e:artifact/src-arn-2" ,
70
- "SourceName" : "source-name-2" ,
71
- "SourceType" : "source-type-2" ,
72
- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-2" ,
73
- "DestinationName" : "dest-name-2" ,
74
- "DestinationType" : "dest-type-2" ,
75
- "AssociationType" : "type-2" ,
76
- }
77
- ]
78
- },
79
- ]
52
+ get_list_associations_side_effect (sagemaker_session )
80
53
81
54
df = viz .show (trial_component_name = name )
82
55
83
56
sagemaker_session .sagemaker_client .describe_trial_component .assert_called_with (
84
57
TrialComponentName = name ,
85
58
)
86
59
87
- expected_calls = [
88
- unittest .mock .call (
89
- DestinationArn = "tc-arn" ,
90
- ),
91
- unittest .mock .call (
92
- SourceArn = "tc-arn" ,
93
- ),
94
- ]
95
- assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
60
+ assert_list_associations_mock_calls (sagemaker_session )
96
61
97
- expected_dataframe = pd .DataFrame .from_dict (
98
- OrderedDict (
99
- [
100
- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
101
- ("Direction" , ["Input" , "Output" ]),
102
- ("Type" , ["source-type-1" , "dest-type-2" ]),
103
- ("Association Type" , ["type-1" , "type-2" ]),
104
- ("Lineage Type" , ["artifact" , "artifact" ]),
105
- ]
106
- )
107
- )
108
-
109
- pd .testing .assert_frame_equal (expected_dataframe , df )
62
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
110
63
111
64
112
65
def test_model_package_arn (viz , sagemaker_session ):
@@ -116,34 +69,7 @@ def test_model_package_arn(viz, sagemaker_session):
116
69
"ArtifactSummaries" : [{"ArtifactArn" : "artifact-arn" }]
117
70
}
118
71
119
- sagemaker_session .sagemaker_client .list_associations .side_effect = [
120
- {
121
- "AssociationSummaries" : [
122
- {
123
- "SourceArn" : "a:b:c:d:e:artifact/src-arn-1" ,
124
- "SourceName" : "source-name-1" ,
125
- "SourceType" : "source-type-1" ,
126
- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-1" ,
127
- "DestinationName" : "dest-name-1" ,
128
- "DestinationType" : "dest-type-1" ,
129
- "AssociationType" : "type-1" ,
130
- }
131
- ]
132
- },
133
- {
134
- "AssociationSummaries" : [
135
- {
136
- "SourceArn" : "a:b:c:d:e:artifact/src-arn-2" ,
137
- "SourceName" : "source-name-2" ,
138
- "SourceType" : "source-type-2" ,
139
- "DestinationArn" : "a:b:c:d:e:artifact/dest-arn-2" ,
140
- "DestinationName" : "dest-name-2" ,
141
- "DestinationType" : "dest-type-2" ,
142
- "AssociationType" : "type-2" ,
143
- }
144
- ]
145
- },
146
- ]
72
+ get_list_associations_side_effect (sagemaker_session )
147
73
148
74
df = viz .show (model_package_arn = name )
149
75
@@ -161,19 +87,7 @@ def test_model_package_arn(viz, sagemaker_session):
161
87
]
162
88
assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
163
89
164
- expected_dataframe = pd .DataFrame .from_dict (
165
- OrderedDict (
166
- [
167
- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
168
- ("Direction" , ["Input" , "Output" ]),
169
- ("Type" , ["source-type-1" , "dest-type-2" ]),
170
- ("Association Type" , ["type-1" , "type-2" ]),
171
- ("Lineage Type" , ["artifact" , "artifact" ]),
172
- ]
173
- )
174
- )
175
-
176
- pd .testing .assert_frame_equal (expected_dataframe , df )
90
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
177
91
178
92
179
93
def test_endpoint_arn (viz , sagemaker_session ):
@@ -183,34 +97,7 @@ def test_endpoint_arn(viz, sagemaker_session):
183
97
"ContextSummaries" : [{"ContextArn" : "context-arn" }]
184
98
}
185
99
186
- sagemaker_session .sagemaker_client .list_associations .side_effect = [
187
- {
188
- "AssociationSummaries" : [
189
- {
190
- "SourceArn" : "a:b:c:d:e:context/src-arn-1" ,
191
- "SourceName" : "source-name-1" ,
192
- "SourceType" : "source-type-1" ,
193
- "DestinationArn" : "a:b:c:d:e:context/dest-arn-1" ,
194
- "DestinationName" : "dest-name-1" ,
195
- "DestinationType" : "dest-type-1" ,
196
- "AssociationType" : "type-1" ,
197
- }
198
- ]
199
- },
200
- {
201
- "AssociationSummaries" : [
202
- {
203
- "SourceArn" : "a:b:c:d:e:context/src-arn-2" ,
204
- "SourceName" : "source-name-2" ,
205
- "SourceType" : "source-type-2" ,
206
- "DestinationArn" : "a:b:c:d:e:context/dest-arn-2" ,
207
- "DestinationName" : "dest-name-2" ,
208
- "DestinationType" : "dest-type-2" ,
209
- "AssociationType" : "type-2" ,
210
- }
211
- ]
212
- },
213
- ]
100
+ get_list_associations_side_effect (sagemaker_session )
214
101
215
102
df = viz .show (endpoint_arn = name )
216
103
@@ -228,27 +115,74 @@ def test_endpoint_arn(viz, sagemaker_session):
228
115
]
229
116
assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
230
117
231
- expected_dataframe = pd .DataFrame .from_dict (
232
- OrderedDict (
233
- [
234
- ("Name/Source" , ["source-name-1" , "dest-name-2" ]),
235
- ("Direction" , ["Input" , "Output" ]),
236
- ("Type" , ["source-type-1" , "dest-type-2" ]),
237
- ("Association Type" , ["type-1" , "type-2" ]),
238
- ("Lineage Type" , ["context" , "context" ]),
239
- ]
240
- )
118
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
119
+
120
+
121
+ def test_processing_job_pipeline_execution_step (viz , sagemaker_session ):
122
+
123
+ sagemaker_session .sagemaker_client .list_trial_components .return_value = {
124
+ "TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
125
+ }
126
+
127
+ get_list_associations_side_effect (sagemaker_session )
128
+
129
+ step = {"Metadata" : {"ProcessingJob" : {"Arn" : "proc-job-arn" }}}
130
+
131
+ df = viz .show (pipeline_execution_step = step )
132
+
133
+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
134
+ SourceArn = "proc-job-arn" ,
241
135
)
242
136
243
- pd . testing . assert_frame_equal ( expected_dataframe , df )
137
+ assert_list_associations_mock_calls ( sagemaker_session )
244
138
139
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
245
140
246
- def test_processing_job_pipeline_execution_step (viz , sagemaker_session ):
141
+
142
+ def test_training_job_pipeline_execution_step (viz , sagemaker_session ):
247
143
248
144
sagemaker_session .sagemaker_client .list_trial_components .return_value = {
249
145
"TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
250
146
}
251
147
148
+ get_list_associations_side_effect (sagemaker_session )
149
+
150
+ step = {"Metadata" : {"TrainingJob" : {"Arn" : "training-job-arn" }}}
151
+
152
+ df = viz .show (pipeline_execution_step = step )
153
+
154
+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
155
+ SourceArn = "training-job-arn" ,
156
+ )
157
+
158
+ assert_list_associations_mock_calls (sagemaker_session )
159
+
160
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
161
+
162
+
163
+ def test_transform_job_pipeline_execution_step (viz , sagemaker_session ):
164
+
165
+ sagemaker_session .sagemaker_client .list_trial_components .return_value = {
166
+ "TrialComponentSummaries" : [{"TrialComponentArn" : "tc-arn" }]
167
+ }
168
+
169
+ get_list_associations_side_effect (sagemaker_session )
170
+
171
+ step = {"Metadata" : {"TransformJob" : {"Arn" : "transform-job-arn" }}}
172
+
173
+ df = viz .show (pipeline_execution_step = step )
174
+
175
+ sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
176
+ SourceArn = "transform-job-arn" ,
177
+ )
178
+
179
+ assert_list_associations_mock_calls (sagemaker_session )
180
+
181
+ pd .testing .assert_frame_equal (get_expected_dataframe (), df )
182
+
183
+
184
+ def get_list_associations_side_effect (sagemaker_session ):
185
+
252
186
sagemaker_session .sagemaker_client .list_associations .side_effect = [
253
187
{
254
188
"AssociationSummaries" : [
@@ -278,13 +212,8 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
278
212
},
279
213
]
280
214
281
- step = {"Metadata" : {"ProcessingJob" : {"Arn" : "proc-job-arn" }}}
282
-
283
- df = viz .show (pipeline_execution_step = step )
284
215
285
- sagemaker_session .sagemaker_client .list_trial_components .assert_called_with (
286
- SourceArn = "proc-job-arn" ,
287
- )
216
+ def assert_list_associations_mock_calls (sagemaker_session ):
288
217
289
218
expected_calls = [
290
219
unittest .mock .call (
@@ -296,6 +225,9 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
296
225
]
297
226
assert expected_calls == sagemaker_session .sagemaker_client .list_associations .mock_calls
298
227
228
+
229
+ def get_expected_dataframe ():
230
+
299
231
expected_dataframe = pd .DataFrame .from_dict (
300
232
OrderedDict (
301
233
[
@@ -308,4 +240,4 @@ def test_processing_job_pipeline_execution_step(viz, sagemaker_session):
308
240
)
309
241
)
310
242
311
- pd . testing . assert_frame_equal ( expected_dataframe , df )
243
+ return expected_dataframe
0 commit comments