diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index 6202de858c..94d07a9655 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -633,7 +633,10 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s Returns: str: The name of the Run object supplied by a user. """ - return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + # TODO: we should revert the lower casting once backend fix reaches prod + return trial_component_name.replace( + "{}{}".format(experiment_name.lower(), DELIMITER), "", 1 + ) @staticmethod def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: @@ -869,6 +872,8 @@ def list_runs( Returns: list: A list of ``Run`` objects. """ + + # all trial components retrieved by default tc_summaries = _TrialComponent.list( experiment_name=experiment_name, created_before=created_before, diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py index 40738e9360..96fc632ad7 100644 --- a/tests/integ/sagemaker/experiments/test_run.py +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -642,6 +642,38 @@ def test_list(run_obj, sagemaker_session): assert run_tcs[0].experiment_config == run_obj.experiment_config +def test_list_twice(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + # note the experiment name used by run_obj is already mixed case and so + # covers the mixed case experiment name double create issue + run_tcs_second_result = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs_second_result[0].run_name == run_obj.run_name + assert run_tcs_second_result[0].experiment_name == run_obj.experiment_name + assert run_tcs_second_result[0].experiment_config == run_obj.experiment_config + + def _generate_estimator( exp_name, sdk_tar, diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py index 0fec9f7fc3..d560462def 100644 --- a/tests/unit/sagemaker/experiments/helpers.py +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -17,6 +17,7 @@ TEST_EXP_NAME = "my-experiment" +TEST_EXP_NAME_MIXED_CASE = "My-eXpeRiMeNt" TEST_RUN_NAME = "my-run" TEST_EXP_DISPLAY_NAME = "my-experiment-display-name" TEST_RUN_DISPLAY_NAME = "my-run-display-name" diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index 7f54fe8d6f..3820d7e4f6 100644 --- a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -48,6 +48,7 @@ mock_trial_load_or_create_func, mock_tc_load_or_create_func, TEST_EXP_NAME, + TEST_EXP_NAME_MIXED_CASE, TEST_RUN_NAME, TEST_EXP_DISPLAY_NAME, TEST_RUN_DISPLAY_NAME, @@ -779,7 +780,9 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ] mock_tc_list.return_value = [ TrialComponentSummary( - trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_name=Run._generate_trial_component_name( + "A" + str(i), TEST_EXP_NAME_MIXED_CASE + ), trial_component_arn="b" + str(i), display_name="C" + str(i), source_arn="D" + str(i), @@ -798,7 +801,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ( _TrialComponent( trial_component_name=Run._generate_trial_component_name( - "a" + str(i), TEST_EXP_NAME + "a" + str(i), TEST_EXP_NAME_MIXED_CASE ), trial_component_arn="b" + str(i), display_name="C" + str(i), @@ -818,14 +821,14 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses ] run_list = list_runs( - experiment_name=TEST_EXP_NAME, + experiment_name=TEST_EXP_NAME_MIXED_CASE, sort_by=SortByType.CREATION_TIME, sort_order=SortOrderType.ASCENDING, sagemaker_session=sagemaker_session, ) mock_tc_list.assert_called_once_with( - experiment_name=TEST_EXP_NAME, + experiment_name=TEST_EXP_NAME_MIXED_CASE, created_before=None, created_after=None, sort_by="CreationTime",