Skip to content

fix: double Run create on load_run #3821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,10 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
Returns:
str: The name of the Run object supplied by a user.
"""
return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1)
# TODO: we should revert the lower casting once backend fix reaches prod
return trial_component_name.replace(
"{}{}".format(experiment_name.lower(), DELIMITER), "", 1
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! So I assume this line would fix the issue of second run created? And will fix the list_runs issue #3673?

Can we add a unit test which does list_runs twice in case to validate in case any future changes break it again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will do


@staticmethod
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
Expand Down Expand Up @@ -869,6 +872,8 @@ def list_runs(
Returns:
list: A list of ``Run`` objects.
"""

# all trial components retrieved by default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the _TrialComponent.list can retrieve all trial components?

return super(_TrialComponent, cls)._list(
"list_trial_components",
_api_types.TrialComponentSummary.from_boto,
"TrialComponentSummaries",
source_arn=source_arn,
created_before=created_before,
created_after=created_after,
sort_by=sort_by,
sort_order=sort_order,
sagemaker_session=sagemaker_session,
trial_name=trial_name,
experiment_name=experiment_name,
max_results=max_results,
next_token=next_token,
)

Or the list trial component API can retrieve all TCs in one call?

If the next_token is not causing any issues, I'd suggest to keep it to avoid a breaking change

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all trial components are retrieved by default:

https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/apiutils/_base_types.py#L118-L147

The problem with this parameter is where is the customer going to get the next_token to supply to method in the first place? The only way to get next_token is to call ListTrialComponents in which case they have no reason to use this method.

Also this method doesn't return a NextToken.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to remove since it is not providing useful behavior as it currently is. In the future we could re-add it but will need to return a NextToken so the customer can supply it in their own page-loop behavior.

tc_summaries = _TrialComponent.list(
experiment_name=experiment_name,
created_before=created_before,
Expand Down
32 changes: 32 additions & 0 deletions tests/integ/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,38 @@ def test_list(run_obj, sagemaker_session):
assert run_tcs[0].experiment_config == run_obj.experiment_config


def test_list_twice(run_obj, sagemaker_session):
tc1 = _TrialComponent.create(
trial_component_name=f"non-run-tc1-{name()}",
sagemaker_session=sagemaker_session,
)
tc2 = _TrialComponent.create(
trial_component_name=f"non-run-tc2-{name()}",
sagemaker_session=sagemaker_session,
tags=TAGS,
)
run_obj._trial.add_trial_component(tc1)
run_obj._trial.add_trial_component(tc2)

run_tcs = list_runs(
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
)
assert len(run_tcs) == 1
assert run_tcs[0].run_name == run_obj.run_name
assert run_tcs[0].experiment_name == run_obj.experiment_name
assert run_tcs[0].experiment_config == run_obj.experiment_config

# note the experiment name used by run_obj is already mixed case and so
# covers the mixed case experiment name double create issue
run_tcs_second_result = list_runs(
experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
)
assert len(run_tcs) == 1
assert run_tcs_second_result[0].run_name == run_obj.run_name
assert run_tcs_second_result[0].experiment_name == run_obj.experiment_name
assert run_tcs_second_result[0].experiment_config == run_obj.experiment_config


def _generate_estimator(
exp_name,
sdk_tar,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/experiments/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


TEST_EXP_NAME = "my-experiment"
TEST_EXP_NAME_MIXED_CASE = "My-eXpeRiMeNt"
TEST_RUN_NAME = "my-run"
TEST_EXP_DISPLAY_NAME = "my-experiment-display-name"
TEST_RUN_DISPLAY_NAME = "my-run-display-name"
Expand Down
11 changes: 7 additions & 4 deletions tests/unit/sagemaker/experiments/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
mock_trial_load_or_create_func,
mock_tc_load_or_create_func,
TEST_EXP_NAME,
TEST_EXP_NAME_MIXED_CASE,
TEST_RUN_NAME,
TEST_EXP_DISPLAY_NAME,
TEST_RUN_DISPLAY_NAME,
Expand Down Expand Up @@ -779,7 +780,9 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
]
mock_tc_list.return_value = [
TrialComponentSummary(
trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME),
trial_component_name=Run._generate_trial_component_name(
"A" + str(i), TEST_EXP_NAME_MIXED_CASE
),
trial_component_arn="b" + str(i),
display_name="C" + str(i),
source_arn="D" + str(i),
Expand All @@ -798,7 +801,7 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
(
_TrialComponent(
trial_component_name=Run._generate_trial_component_name(
"a" + str(i), TEST_EXP_NAME
"a" + str(i), TEST_EXP_NAME_MIXED_CASE
),
trial_component_arn="b" + str(i),
display_name="C" + str(i),
Expand All @@ -818,14 +821,14 @@ def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_ses
]

run_list = list_runs(
experiment_name=TEST_EXP_NAME,
experiment_name=TEST_EXP_NAME_MIXED_CASE,
sort_by=SortByType.CREATION_TIME,
sort_order=SortOrderType.ASCENDING,
sagemaker_session=sagemaker_session,
)

mock_tc_list.assert_called_once_with(
experiment_name=TEST_EXP_NAME,
experiment_name=TEST_EXP_NAME_MIXED_CASE,
created_before=None,
created_after=None,
sort_by="CreationTime",
Expand Down