Skip to content

Commit 5a9cd8c

Browse files
committed
fix: double Run create on load_run
fixes #3673
1 parent 3a8a2e7 commit 5a9cd8c

File tree

4 files changed

+46
-5
lines changed

4 files changed

+46
-5
lines changed

src/sagemaker/experiments/run.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,10 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
633633
Returns:
634634
str: The name of the Run object supplied by a user.
635635
"""
636-
return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1)
636+
# TODO: we should revert the lower casting once backend fix reaches prod
637+
return trial_component_name.replace(
638+
"{}{}".format(experiment_name.lower(), DELIMITER), "", 1
639+
)
637640

638641
@staticmethod
639642
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
@@ -869,6 +872,8 @@ def list_runs(
869872
Returns:
870873
list: A list of ``Run`` objects.
871874
"""
875+
876+
# all trial components retrieved by default
872877
tc_summaries = _TrialComponent.list(
873878
experiment_name=experiment_name,
874879
created_before=created_before,

tests/integ/sagemaker/experiments/test_run.py

+32
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,38 @@ def test_list(run_obj, sagemaker_session):
642642
assert run_tcs[0].experiment_config == run_obj.experiment_config
643643

644644

645+
def test_list_twice(run_obj, sagemaker_session):
646+
tc1 = _TrialComponent.create(
647+
trial_component_name=f"non-run-tc1-{name()}",
648+
sagemaker_session=sagemaker_session,
649+
)
650+
tc2 = _TrialComponent.create(
651+
trial_component_name=f"non-run-tc2-{name()}",
652+
sagemaker_session=sagemaker_session,
653+
tags=TAGS,
654+
)
655+
run_obj._trial.add_trial_component(tc1)
656+
run_obj._trial.add_trial_component(tc2)
657+
658+
run_tcs = list_runs(
659+
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
660+
)
661+
assert len(run_tcs) == 1
662+
assert run_tcs[0].run_name == run_obj.run_name
663+
assert run_tcs[0].experiment_name == run_obj.experiment_name
664+
assert run_tcs[0].experiment_config == run_obj.experiment_config
665+
666+
# note the experiment name used by run_obj is already mixed case and so
667+
# covers the mixed case experiment name double create issue
668+
run_tcs_second_result = list_runs(
669+
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
670+
)
671+
assert len(run_tcs) == 1
672+
assert run_tcs_second_result[0].run_name == run_obj.run_name
673+
assert run_tcs_second_result[0].experiment_name == run_obj.experiment_name
674+
assert run_tcs_second_result[0].experiment_config == run_obj.experiment_config
675+
676+
645677
def _generate_estimator(
646678
exp_name,
647679
sdk_tar,

tests/unit/sagemaker/experiments/helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
TEST_EXP_NAME = "my-experiment"
20+
TEST_EXP_NAME_MIXED_CASE = "My-eXpeRiMeNt"
2021
TEST_RUN_NAME = "my-run"
2122
TEST_EXP_DISPLAY_NAME = "my-experiment-display-name"
2223
TEST_RUN_DISPLAY_NAME = "my-run-display-name"

tests/unit/sagemaker/experiments/test_run.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
mock_trial_load_or_create_func,
4949
mock_tc_load_or_create_func,
5050
TEST_EXP_NAME,
51+
TEST_EXP_NAME_MIXED_CASE,
5152
TEST_RUN_NAME,
5253
TEST_EXP_DISPLAY_NAME,
5354
TEST_RUN_DISPLAY_NAME,
@@ -779,7 +780,9 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
779780
]
780781
mock_tc_list.return_value = [
781782
TrialComponentSummary(
782-
trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME),
783+
trial_component_name=Run._generate_trial_component_name(
784+
"A" + str(i), TEST_EXP_NAME_MIXED_CASE
785+
),
783786
trial_component_arn="b" + str(i),
784787
display_name="C" + str(i),
785788
source_arn="D" + str(i),
@@ -798,7 +801,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
798801
(
799802
_TrialComponent(
800803
trial_component_name=Run._generate_trial_component_name(
801-
"a" + str(i), TEST_EXP_NAME
804+
"a" + str(i), TEST_EXP_NAME_MIXED_CASE
802805
),
803806
trial_component_arn="b" + str(i),
804807
display_name="C" + str(i),
@@ -818,14 +821,14 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
818821
]
819822

820823
run_list = list_runs(
821-
experiment_name=TEST_EXP_NAME,
824+
experiment_name=TEST_EXP_NAME_MIXED_CASE,
822825
sort_by=SortByType.CREATION_TIME,
823826
sort_order=SortOrderType.ASCENDING,
824827
sagemaker_session=sagemaker_session,
825828
)
826829

827830
mock_tc_list.assert_called_once_with(
828-
experiment_name=TEST_EXP_NAME,
831+
experiment_name=TEST_EXP_NAME_MIXED_CASE,
829832
created_before=None,
830833
created_after=None,
831834
sort_by="CreationTime",

0 commit comments

Comments
 (0)