@@ -114,19 +114,39 @@ def test_pipeline_execution_arn(sagemaker_session):
114
114
trial_component_arn = (
115
115
"arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
116
116
)
117
- obj = lineage_trial_component . LineageTrialComponent (
118
- sagemaker_session , trial_component_name = "foo" , trial_component_arn = trial_component_arn
117
+ training_job_arn = (
118
+ "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
119
119
)
120
+ context = lineage_trial_component .LineageTrialComponent (
121
+ sagemaker_session ,
122
+ trial_component_name = "foo" ,
123
+ trial_component_arn = trial_component_arn ,
124
+ source = {
125
+ "SourceArn" : training_job_arn ,
126
+ "SourceType" : "SageMakerTrainingJob" ,
127
+ },
128
+ )
129
+ obj = {
130
+ "TrialComponentName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
131
+ "TrialComponentArn" : trial_component_arn ,
132
+ "DisplayName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
133
+ "Source" : {
134
+ "SourceArn" : training_job_arn ,
135
+ "SourceType" : "SageMakerTrainingJob" ,
136
+ },
137
+ }
138
+ sagemaker_session .sagemaker_client .describe_trial_component .return_value = obj
120
139
121
140
sagemaker_session .sagemaker_client .list_tags .return_value = {
122
141
"Tags" : [
123
142
{"Key" : "sagemaker:pipeline-execution-arn" , "Value" : "tag1" },
124
143
],
125
144
}
126
145
expected_calls = [
127
- unittest .mock .call (ResourceArn = trial_component_arn ),
146
+ unittest .mock .call (ResourceArn = training_job_arn ),
128
147
]
129
- pipeline_execution_arn_result = obj .pipeline_execution_arn ()
148
+ pipeline_execution_arn_result = context .pipeline_execution_arn ()
149
+
130
150
assert pipeline_execution_arn_result == "tag1"
131
151
assert expected_calls == sagemaker_session .sagemaker_client .list_tags .mock_calls
132
152
@@ -135,19 +155,77 @@ def test_no_pipeline_execution_arn(sagemaker_session):
135
155
trial_component_arn = (
136
156
"arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
137
157
)
138
- obj = lineage_trial_component . LineageTrialComponent (
139
- sagemaker_session , trial_component_name = "foo" , trial_component_arn = trial_component_arn
158
+ training_job_arn = (
159
+ "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
140
160
)
161
+ context = lineage_trial_component .LineageTrialComponent (
162
+ sagemaker_session ,
163
+ trial_component_name = "foo" ,
164
+ trial_component_arn = trial_component_arn ,
165
+ source = {
166
+ "SourceArn" : training_job_arn ,
167
+ "SourceType" : "SageMakerTrainingJob" ,
168
+ },
169
+ )
170
+ obj = {
171
+ "TrialComponentName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
172
+ "TrialComponentArn" : trial_component_arn ,
173
+ "DisplayName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
174
+ "Source" : {
175
+ "SourceArn" : training_job_arn ,
176
+ "SourceType" : "SageMakerTrainingJob" ,
177
+ },
178
+ }
179
+ sagemaker_session .sagemaker_client .describe_trial_component .return_value = obj
141
180
142
181
sagemaker_session .sagemaker_client .list_tags .return_value = {
143
182
"Tags" : [
144
183
{"Key" : "abcd" , "Value" : "efg" },
145
184
],
146
185
}
147
186
expected_calls = [
148
- unittest .mock .call (ResourceArn = trial_component_arn ),
187
+ unittest .mock .call (ResourceArn = training_job_arn ),
149
188
]
150
- pipeline_execution_arn_result = obj .pipeline_execution_arn ()
189
+ pipeline_execution_arn_result = context .pipeline_execution_arn ()
190
+ expected_result = None
191
+ assert pipeline_execution_arn_result == expected_result
192
+ assert expected_calls == sagemaker_session .sagemaker_client .list_tags .mock_calls
193
+
194
+
195
+ def test_no_source_arn_pipeline_execution_arn (sagemaker_session ):
196
+ trial_component_arn = (
197
+ "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
198
+ )
199
+ training_job_arn = (
200
+ "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
201
+ )
202
+ context = lineage_trial_component .LineageTrialComponent (
203
+ sagemaker_session ,
204
+ trial_component_name = "foo" ,
205
+ trial_component_arn = trial_component_arn ,
206
+ source = {
207
+ "SourceArn" : training_job_arn ,
208
+ "SourceType" : "SageMakerTrainingJob" ,
209
+ },
210
+ )
211
+ obj = {
212
+ "TrialComponentName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
213
+ "TrialComponentArn" : trial_component_arn ,
214
+ "DisplayName" : "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job" ,
215
+ "Source" : {
216
+ "SourceArn" : None ,
217
+ "SourceType" : None ,
218
+ },
219
+ }
220
+ sagemaker_session .sagemaker_client .describe_trial_component .return_value = obj
221
+
222
+ sagemaker_session .sagemaker_client .list_tags .return_value = {
223
+ "Tags" : [
224
+ {"Key" : "abcd" , "Value" : "efg" },
225
+ ],
226
+ }
227
+ expected_calls = []
228
+ pipeline_execution_arn_result = context .pipeline_execution_arn ()
151
229
expected_result = None
152
230
assert pipeline_execution_arn_result == expected_result
153
231
assert expected_calls == sagemaker_session .sagemaker_client .list_tags .mock_calls
0 commit comments