Skip to content

Commit 5efbbd7

Browse files
committed
fix: double Run create on load_run
fixes aws#3673
1 parent 3a8a2e7 commit 5efbbd7

File tree

4 files changed

+45
-5
lines changed

4 files changed

+45
-5
lines changed

src/sagemaker/experiments/run.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,10 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
633633
Returns:
634634
str: The name of the Run object supplied by a user.
635635
"""
636-
return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1)
636+
# TODO: we should revert the lower casting once backend fix reaches prod
637+
return trial_component_name.replace(
638+
"{}{}".format(experiment_name.lower(), DELIMITER), "", 1
639+
)
637640

638641
@staticmethod
639642
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
@@ -869,6 +872,8 @@ def list_runs(
869872
Returns:
870873
list: A list of ``Run`` objects.
871874
"""
875+
876+
# all trial components retrieved by default
872877
tc_summaries = _TrialComponent.list(
873878
experiment_name=experiment_name,
874879
created_before=created_before,

tests/integ/sagemaker/experiments/test_run.py

+31
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,37 @@ def test_list(run_obj, sagemaker_session):
641641
assert run_tcs[0].experiment_name == run_obj.experiment_name
642642
assert run_tcs[0].experiment_config == run_obj.experiment_config
643643

644+
def test_list_twice(run_obj, sagemaker_session):
645+
tc1 = _TrialComponent.create(
646+
trial_component_name=f"non-run-tc1-{name()}",
647+
sagemaker_session=sagemaker_session,
648+
)
649+
tc2 = _TrialComponent.create(
650+
trial_component_name=f"non-run-tc2-{name()}",
651+
sagemaker_session=sagemaker_session,
652+
tags=TAGS,
653+
)
654+
run_obj._trial.add_trial_component(tc1)
655+
run_obj._trial.add_trial_component(tc2)
656+
657+
run_tcs = list_runs(
658+
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
659+
)
660+
assert len(run_tcs) == 1
661+
assert run_tcs[0].run_name == run_obj.run_name
662+
assert run_tcs[0].experiment_name == run_obj.experiment_name
663+
assert run_tcs[0].experiment_config == run_obj.experiment_config
664+
665+
# note the experiment name used by run_obj is already mixed case and so
666+
# covers the mixed case experiment name double create issue
667+
run_tcs_second_result = list_runs(
668+
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
669+
)
670+
assert len(run_tcs) == 1
671+
assert run_tcs_second_result[0].run_name == run_obj.run_name
672+
assert run_tcs_second_result[0].experiment_name == run_obj.experiment_name
673+
assert run_tcs_second_result[0].experiment_config == run_obj.experiment_config
674+
644675

645676
def _generate_estimator(
646677
exp_name,

tests/unit/sagemaker/experiments/helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
TEST_EXP_NAME = "my-experiment"
20+
TEST_EXP_NAME_MIXED_CASE = "My-eXpeRiMeNt"
2021
TEST_RUN_NAME = "my-run"
2122
TEST_EXP_DISPLAY_NAME = "my-experiment-display-name"
2223
TEST_RUN_DISPLAY_NAME = "my-run-display-name"

tests/unit/sagemaker/experiments/test_run.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
mock_trial_load_or_create_func,
4949
mock_tc_load_or_create_func,
5050
TEST_EXP_NAME,
51+
TEST_EXP_NAME_MIXED_CASE,
5152
TEST_RUN_NAME,
5253
TEST_EXP_DISPLAY_NAME,
5354
TEST_RUN_DISPLAY_NAME,
@@ -779,7 +780,9 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
779780
]
780781
mock_tc_list.return_value = [
781782
TrialComponentSummary(
782-
trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME),
783+
trial_component_name=Run._generate_trial_component_name(
784+
"A" + str(i), TEST_EXP_NAME_MIXED_CASE
785+
),
783786
trial_component_arn="b" + str(i),
784787
display_name="C" + str(i),
785788
source_arn="D" + str(i),
@@ -798,7 +801,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
798801
(
799802
_TrialComponent(
800803
trial_component_name=Run._generate_trial_component_name(
801-
"a" + str(i), TEST_EXP_NAME
804+
"a" + str(i), TEST_EXP_NAME_MIXED_CASE
802805
),
803806
trial_component_arn="b" + str(i),
804807
display_name="C" + str(i),
@@ -818,14 +821,14 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
818821
]
819822

820823
run_list = list_runs(
821-
experiment_name=TEST_EXP_NAME,
824+
experiment_name=TEST_EXP_NAME_MIXED_CASE,
822825
sort_by=SortByType.CREATION_TIME,
823826
sort_order=SortOrderType.ASCENDING,
824827
sagemaker_session=sagemaker_session,
825828
)
826829

827830
mock_tc_list.assert_called_once_with(
828-
experiment_name=TEST_EXP_NAME,
831+
experiment_name=TEST_EXP_NAME_MIXED_CASE,
829832
created_before=None,
830833
created_after=None,
831834
sort_by="CreationTime",

0 commit comments

Comments
 (0)