55
55
TEST_RUN_DISPLAY_NAME ,
56
56
TEST_ARTIFACT_BUCKET ,
57
57
TEST_ARTIFACT_PREFIX ,
58
+ TEST_TAGS
58
59
)
59
60
60
61
@@ -155,24 +156,22 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
155
156
156
157
157
158
@pytest .mark .parametrize (
158
- ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" ),
159
+ ("kwargs" , "expected_artifact_bucket" , "expected_artifact_prefix" , "expected_tags" ),
159
160
[
160
- ({}, None , _DEFAULT_ARTIFACT_PREFIX ),
161
+ ({}, None , _DEFAULT_ARTIFACT_PREFIX , None ),
161
162
(
162
163
{
163
164
"artifact_bucket" : TEST_ARTIFACT_BUCKET ,
164
165
"artifact_prefix" : TEST_ARTIFACT_PREFIX ,
166
+ "tags" : TEST_TAGS
165
167
},
166
168
TEST_ARTIFACT_BUCKET ,
167
169
TEST_ARTIFACT_PREFIX ,
170
+ TEST_TAGS
168
171
),
169
172
],
170
173
)
171
174
@patch .object (_TrialComponent , "save" , MagicMock (return_value = None ))
172
- @patch (
173
- "sagemaker.experiments.run.Experiment._load_or_create" ,
174
- MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME )),
175
- )
176
175
@patch (
177
176
"sagemaker.experiments.run._Trial._load_or_create" ,
178
177
MagicMock (side_effect = mock_trial_load_or_create_func ),
@@ -189,6 +188,7 @@ def test_run_load_no_run_name_and_in_train_job(
189
188
kwargs ,
190
189
expected_artifact_bucket ,
191
190
expected_artifact_prefix ,
191
+ expected_tags
192
192
):
193
193
client = sagemaker_session .sagemaker_client
194
194
job_name = "my-train-job"
@@ -220,19 +220,22 @@ def test_run_load_no_run_name_and_in_train_job(
220
220
}
221
221
]
222
222
}
223
- with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
224
- assert run_obj ._in_load
225
- assert not run_obj ._inside_init_context
226
- assert run_obj ._inside_load_context
227
- assert run_obj .run_name == TEST_RUN_NAME
228
- assert run_obj ._trial_component .trial_component_name == expected_tc_name
229
- assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
230
- assert run_obj ._trial
231
- assert run_obj .experiment_name == TEST_EXP_NAME
232
- assert run_obj ._experiment
233
- assert run_obj .experiment_config == exp_config
234
- assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
235
- assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
223
+ expmock = MagicMock (return_value = Experiment (experiment_name = TEST_EXP_NAME ,tags = expected_tags ))
224
+ with patch ("sagemaker.experiments.run.Experiment._load_or_create" , expmock ):
225
+ with load_run (sagemaker_session = sagemaker_session , ** kwargs ) as run_obj :
226
+ assert run_obj ._in_load
227
+ assert not run_obj ._inside_init_context
228
+ assert run_obj ._inside_load_context
229
+ assert run_obj .run_name == TEST_RUN_NAME
230
+ assert run_obj ._trial_component .trial_component_name == expected_tc_name
231
+ assert run_obj .run_group_name == Run ._generate_trial_name (TEST_EXP_NAME )
232
+ assert run_obj ._trial
233
+ assert run_obj .experiment_name == TEST_EXP_NAME
234
+ assert run_obj ._experiment
235
+ assert run_obj .experiment_config == exp_config
236
+ assert run_obj ._artifact_uploader .artifact_bucket == expected_artifact_bucket
237
+ assert run_obj ._artifact_uploader .artifact_prefix == expected_artifact_prefix
238
+ assert run_obj ._experiment .tags == expected_tags
236
239
237
240
client .describe_training_job .assert_called_once_with (TrainingJobName = job_name )
238
241
run_obj ._trial .add_trial_component .assert_not_called ()
0 commit comments