Skip to content

Commit 02a37a8

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
committed
fix: Fix run name uniqueness issue (aws#730)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 796d6d5 commit 02a37a8

File tree

12 files changed

+366
-587
lines changed

12 files changed

+366
-587
lines changed

.gitignore

-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,5 @@ env/
3131
**/tmp
3232
.python-version
3333
**/_repack_script_launcher.sh
34-
tests/data/experiment/docker/boto
35-
tests/data/experiment/docker/sagemaker-dev.tar.gz
3634
tests/data/**/_repack_model.py
3735
tests/data/experiment/resources/sagemaker-beta-1.0.tar.gz

src/sagemaker/experiments/experiment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def list_trials(self, created_before=None, created_after=None, sort_by=None, sor
192192
sagemaker_session=self.sagemaker_session,
193193
)
194194

195-
def delete_all(self, action):
195+
def _delete_all(self, action):
196196
"""Force to delete the experiment and associated trials, trial components.
197197
198198
Args:

src/sagemaker/experiments/run.py

+119-194
Large diffs are not rendered by default.

src/sagemaker/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -938,4 +938,4 @@ def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None
938938
)
939939
return experiment_config
940940

941-
return run_obj.experiment_config if run_obj else None
941+
return run_obj._experiment_config if run_obj else None

tests/integ/sagemaker/experiments/helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ def cleanup_exp_resources(exp_names, sagemaker_session):
3939
finally:
4040
for exp_name in exp_names:
4141
exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
42-
exp.delete_all(action="--force")
42+
exp._delete_all(action="--force")

tests/integ/sagemaker/experiments/test_run.py

+136-73
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@
2323
from sagemaker.sklearn import SKLearn
2424
from sagemaker.utils import retry_with_backoff
2525
from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources
26-
from sagemaker.experiments.run import Run, RUN_NAME_BASE
26+
from sagemaker.experiments.run import (
27+
Run,
28+
RUN_NAME_BASE,
29+
DELIMITER,
30+
RUN_TC_TAG_KEY,
31+
RUN_TC_TAG_VALUE,
32+
)
2733
from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX
2834

2935

@@ -60,24 +66,20 @@ def lineage_artifact_path(tempdir):
6066
return file_path
6167

6268

63-
def test_local_run(
69+
file_artifact_name = f"file-artifact-{name()}"
70+
lineage_artifact_name = f"lineage-file-artifact-{name()}"
71+
metric_name = "test-x-step"
72+
73+
74+
def test_local_run_with_load_specifying_names(
6475
sagemaker_session, artifact_file_path, artifact_directory, lineage_artifact_path
6576
):
6677
exp_name = f"my-local-exp-{name()}"
67-
exp_name2 = f"{exp_name}-2"
68-
file_artifact_name = "file-artifact"
69-
lineage_artifact_name = "lineage-file-artifact"
70-
table_artifact_name = "TestTableTitle"
71-
metric_name = "test-x-step"
72-
73-
with cleanup_exp_resources(
74-
exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session
75-
):
78+
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
7679
# Run name is not provided, will create a new TC
7780
with Run.init(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1:
7881
run1_name = run1.run_name
79-
run1_exp_name = run1.experiment_name
80-
run1_trial_name = run1._trial.trial_name
82+
assert RUN_NAME_BASE in run1_name
8183

8284
run1.log_parameter("p1", 1.0)
8385
run1.log_parameter("p2", "p2-value")
@@ -86,73 +88,134 @@ def test_local_run(
8688
run1.log_artifact_file(file_path=artifact_file_path, name=file_artifact_name)
8789
run1.log_artifact_directory(directory=artifact_directory, is_output=False)
8890
run1.log_lineage_artifact(file_path=lineage_artifact_path, name=lineage_artifact_name)
89-
run1.log_table(
90-
title=table_artifact_name, values={"x": [1, 2, 3], "y": [4, 5, 6]}, is_output=False
91-
)
9291

9392
for i in range(_MetricsManager._BATCH_SIZE):
9493
run1.log_metric(name=metric_name, value=i, step=i)
9594

96-
assert RUN_NAME_BASE in run1_name
97-
98-
def validate_tc_artifact_association(is_output, expected_artifact_name):
99-
if is_output:
100-
# It's an output association from the tc
101-
response = sagemaker_session.sagemaker_client.list_associations(
102-
SourceArn=tc.trial_component_arn
103-
)
104-
else:
105-
# It's an input association to the tc
106-
response = sagemaker_session.sagemaker_client.list_associations(
107-
DestinationArn=tc.trial_component_arn
108-
)
109-
associations = response["AssociationSummaries"]
110-
111-
assert len(associations) == 1
112-
summary = associations[0]
113-
if is_output:
114-
assert summary["SourceArn"] == tc.trial_component_arn
115-
assert summary["DestinationName"] == expected_artifact_name
116-
else:
117-
assert summary["DestinationArn"] == tc.trial_component_arn
118-
assert summary["SourceName"] == expected_artifact_name
119-
120-
# Run name is passed from the name of an existing TC.
121-
# Meanwhile, the experiment_name is changed.
122-
# Should load TC from backend.
123-
with Run.init(
124-
experiment_name=exp_name2,
95+
with Run.load(
96+
experiment_name=exp_name,
12597
run_name=run1_name,
12698
sagemaker_session=sagemaker_session,
12799
) as run2:
128-
assert run1_exp_name != run2.experiment_name
129-
assert run1_trial_name != run2._trial.trial_name
130-
assert run1_name == run2.run_name
131-
132-
tc = run2._trial_component
133-
assert tc.parameters == {"p1": 1.0, "p2": "p2-value", "p3": 2.0, "p4": "p4-value"}
134-
135-
s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}"
136-
assert s3_prefix in tc.output_artifacts[file_artifact_name].value
137-
assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type
138-
assert s3_prefix in tc.input_artifacts["artifact_file1"].value
139-
assert "text/plain" == tc.input_artifacts["artifact_file1"].media_type
140-
assert s3_prefix in tc.input_artifacts["artifact_file2"].value
141-
assert "text/plain" == tc.input_artifacts["artifact_file2"].media_type
142-
143-
assert len(tc.metrics) == 1
144-
metric_summary = tc.metrics[0]
145-
assert metric_summary.metric_name == metric_name
146-
assert metric_summary.max == 9.0
147-
assert metric_summary.min == 0.0
100+
assert run2.run_name == run1_name
101+
assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run1_name}"
102+
_check_run_from_local_end_result(
103+
sagemaker_session=sagemaker_session, tc=run2._trial_component
104+
)
105+
148106

149-
validate_tc_artifact_association(
150-
is_output=True,
151-
expected_artifact_name=lineage_artifact_name,
107+
def _check_run_from_local_end_result(sagemaker_session, tc):
108+
def validate_tc_artifact_association(is_output, expected_artifact_name):
109+
if is_output:
110+
# It's an output association from the tc
111+
response = sagemaker_session.sagemaker_client.list_associations(
112+
SourceArn=tc.trial_component_arn
113+
)
114+
else:
115+
# It's an input association to the tc
116+
response = sagemaker_session.sagemaker_client.list_associations(
117+
DestinationArn=tc.trial_component_arn
152118
)
153-
validate_tc_artifact_association(
154-
is_output=False,
155-
expected_artifact_name=table_artifact_name,
119+
associations = response["AssociationSummaries"]
120+
121+
assert len(associations) == 1
122+
summary = associations[0]
123+
if is_output:
124+
assert summary["SourceArn"] == tc.trial_component_arn
125+
assert summary["DestinationName"] == expected_artifact_name
126+
else:
127+
assert summary["DestinationArn"] == tc.trial_component_arn
128+
assert summary["SourceName"] == expected_artifact_name
129+
130+
assert tc.parameters == {"p1": 1.0, "p2": "p2-value", "p3": 2.0, "p4": "p4-value"}
131+
132+
s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}"
133+
assert s3_prefix in tc.output_artifacts[file_artifact_name].value
134+
assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type
135+
assert s3_prefix in tc.input_artifacts["artifact_file1"].value
136+
assert "text/plain" == tc.input_artifacts["artifact_file1"].media_type
137+
assert s3_prefix in tc.input_artifacts["artifact_file2"].value
138+
assert "text/plain" == tc.input_artifacts["artifact_file2"].media_type
139+
140+
assert len(tc.metrics) == 1
141+
metric_summary = tc.metrics[0]
142+
assert metric_summary.metric_name == metric_name
143+
assert metric_summary.max == 9.0
144+
assert metric_summary.min == 0.0
145+
146+
validate_tc_artifact_association(
147+
is_output=True,
148+
expected_artifact_name=lineage_artifact_name,
149+
)
150+
151+
152+
def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session):
153+
exp_name1 = f"my-two-local-exp1-{name()}"
154+
exp_name2 = f"my-two-local-exp2-{name()}"
155+
run_name = "test-run"
156+
with cleanup_exp_resources(
157+
exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session
158+
):
159+
# Run name is not provided, will create a new TC
160+
with Run.init(
161+
experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session
162+
) as run1:
163+
pass
164+
with Run.init(
165+
experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session
166+
) as run2:
167+
pass
168+
169+
assert run1.experiment_name != run2.experiment_name
170+
assert run1.run_name == run2.run_name
171+
assert (
172+
run1._trial_component.trial_component_name != run2._trial_component.trial_component_name
173+
)
174+
assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}"
175+
assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}"
176+
177+
178+
@pytest.mark.parametrize(
179+
"input_names",
180+
[
181+
(f"my-local-exp-{name()}", "test-run", None), # both have delimiter -
182+
("my-test-1", "my-test-1", None), # exp_name equals run_name
183+
("my-test-3", "my-test-3-run", None), # <exp_name><delimiter> is subset of run_name
184+
("x" * 59, "test-run", None), # long exp_name
185+
("test-exp", "y" * 59, None), # long run_name
186+
("x" * 59, "y" * 59, None), # long exp_name and run_name
187+
("my-test4", "test-run", "run-display-name-test"), # with supplied display name
188+
],
189+
)
190+
def test_run_name_vs_trial_component_name_edge_cases(
191+
sagemaker_session, artifact_file_path, artifact_directory, lineage_artifact_path, input_names
192+
):
193+
exp_name, run_name, run_display_name = input_names
194+
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
195+
with Run.init(
196+
experiment_name=exp_name,
197+
sagemaker_session=sagemaker_session,
198+
run_name=run_name,
199+
run_display_name=run_display_name,
200+
) as run1:
201+
assert not run1._experiment.tags
202+
assert not run1._trial.tags
203+
tags = run1._trial_component.tags
204+
assert len(tags) == 1
205+
assert tags[0]["Key"] == RUN_TC_TAG_KEY
206+
assert tags[0]["Value"] == RUN_TC_TAG_VALUE
207+
208+
with Run.load(
209+
experiment_name=exp_name,
210+
run_name=run_name,
211+
sagemaker_session=sagemaker_session,
212+
) as run2:
213+
assert run2.experiment_name == exp_name
214+
assert run2.run_name == run_name
215+
assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}"
216+
assert run2._trial_component.display_name in (
217+
run_display_name,
218+
run2._trial_component.trial_component_name,
156219
)
157220

158221

@@ -201,7 +264,7 @@ def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, j
201264
)
202265
estimator.fit(
203266
job_name=f"train-job-{name()}",
204-
experiment_config=run.experiment_config,
267+
experiment_config=run._experiment_config,
205268
wait=True, # wait the training job to finish
206269
logs="None", # set to "All" to display logs fetched from the training job
207270
)
@@ -248,7 +311,7 @@ def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, j
248311
)
249312
estimator.fit(
250313
job_name=f"train-job-{name()}",
251-
experiment_config=run.experiment_config,
314+
experiment_config=run._experiment_config,
252315
wait=True, # wait the training job to finish
253316
logs="None", # set to "All" to display logs fetched from the training job
254317
)

tests/unit/sagemaker/experiments/conftest.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
mock_tc_load_or_create_func,
2525
mock_trial_load_or_create_func,
2626
TEST_EXP_NAME,
27-
TEST_RUN_GRP_NAME,
2827
)
2928

3029

@@ -74,14 +73,13 @@ def run_obj(sagemaker_session):
7473
):
7574
run = Run.init(
7675
experiment_name=TEST_EXP_NAME,
77-
run_group_name=TEST_RUN_GRP_NAME,
7876
sagemaker_session=sagemaker_session,
7977
)
8078
run._artifact_uploader = Mock()
8179
run._lineage_artifact_tracker = Mock()
8280
run._metrics_manager = Mock()
8381

8482
assert run.run_name.startswith(RUN_NAME_BASE)
85-
assert run.run_group_name == TEST_RUN_GRP_NAME
83+
assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
8684

8785
return run

tests/unit/sagemaker/experiments/helpers.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818

1919
TEST_EXP_NAME = "my-experiment"
20-
TEST_RUN_GRP_NAME = "my-run-group"
2120
TEST_RUN_NAME = "my-run"
2221

2322

tests/unit/sagemaker/experiments/test_experiment.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def test_list_trials_call_args(sagemaker_session):
210210
def test_delete_all_with_incorrect_action_name(sagemaker_session):
211211
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
212212
with pytest.raises(ValueError) as err:
213-
obj.delete_all(action="abc")
213+
obj._delete_all(action="abc")
214214

215215
assert "Must confirm with string '--force'" in str(err)
216216

@@ -278,7 +278,7 @@ def test_delete_all(sagemaker_session):
278278
client.delete_trial.return_value = {}
279279
client.delete_experiment.return_value = {}
280280

281-
obj.delete_all(action="--force")
281+
obj._delete_all(action="--force")
282282

283283
client.delete_experiment.assert_called_with(ExperimentName="foo")
284284

@@ -301,6 +301,6 @@ def test_delete_all_fail(sagemaker_session):
301301
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
302302
sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
303303
with pytest.raises(Exception) as e:
304-
obj.delete_all(action="--force")
304+
obj._delete_all(action="--force")
305305

306306
assert str(e.value) == "Failed to delete, please try again."

0 commit comments

Comments
 (0)