Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fe231bf

Browse files
committedFeb 17, 2023
fix: add tag fixes for pipelines and experiments
1 parent 607f85a commit fe231bf

File tree

10 files changed

+149
-55
lines changed

10 files changed

+149
-55
lines changed
 

‎src/sagemaker/experiments/experiment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments.trial import _Trial
2022
from sagemaker.experiments.trial_component import _TrialComponent
@@ -154,17 +156,21 @@ def _load_or_create(
154156
Returns:
155157
experiments.experiment._Experiment: A SageMaker `_Experiment` object
156158
"""
157-
sagemaker_client = sagemaker_session.sagemaker_client
158159
try:
159-
experiment = _Experiment.load(experiment_name, sagemaker_session)
160-
except sagemaker_client.exceptions.ResourceNotFound:
161160
experiment = _Experiment.create(
162161
experiment_name=experiment_name,
163162
display_name=display_name,
164163
description=description,
165164
tags=tags,
166165
sagemaker_session=sagemaker_session,
167166
)
167+
except ClientError as ce:
168+
error_code = ce.response["Error"]["Code"]
169+
error_message = ce.response["Error"]["Message"]
170+
if not (error_code == "ValidationException" and "already exists" in error_message):
171+
raise ce
172+
# already exists
173+
experiment = _Experiment.load(experiment_name, sagemaker_session)
168174
return experiment
169175

170176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

‎src/sagemaker/experiments/trial.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Contains the Trial class."""
1414
from __future__ import absolute_import
1515

16+
from botocore.exceptions import ClientError
17+
1618
from sagemaker.apiutils import _base_types
1719
from sagemaker.experiments import _api_types
1820
from sagemaker.experiments.trial_component import _TrialComponent
@@ -268,8 +270,20 @@ def _load_or_create(
268270
Returns:
269271
experiments.trial._Trial: A SageMaker `_Trial` object
270272
"""
271-
sagemaker_client = sagemaker_session.sagemaker_client
272273
try:
274+
trial = _Trial.create(
275+
experiment_name=experiment_name,
276+
trial_name=trial_name,
277+
display_name=display_name,
278+
tags=tags,
279+
sagemaker_session=sagemaker_session,
280+
)
281+
except ClientError as ce:
282+
error_code = ce.response["Error"]["Code"]
283+
error_message = ce.response["Error"]["Message"]
284+
if not (error_code == "ValidationException" and "already exists" in error_message):
285+
raise ce
286+
# already exists
273287
trial = _Trial.load(trial_name, sagemaker_session)
274288
if trial.experiment_name != experiment_name: # pylint: disable=no-member
275289
raise ValueError(
@@ -278,12 +292,4 @@ def _load_or_create(
278292
trial.experiment_name # pylint: disable=no-member
279293
)
280294
)
281-
except sagemaker_client.exceptions.ResourceNotFound:
282-
trial = _Trial.create(
283-
experiment_name=experiment_name,
284-
trial_name=trial_name,
285-
display_name=display_name,
286-
tags=tags,
287-
sagemaker_session=sagemaker_session,
288-
)
289295
return trial

‎src/sagemaker/experiments/trial_component.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments import _api_types
2022
from sagemaker.experiments._api_types import TrialComponentSearchResult
@@ -326,16 +328,20 @@ def _load_or_create(
326328
experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
327329
bool: A boolean variable indicating whether the trail component already exists
328330
"""
329-
sagemaker_client = sagemaker_session.sagemaker_client
330331
is_existed = False
331332
try:
332-
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
333-
is_existed = True
334-
except sagemaker_client.exceptions.ResourceNotFound:
335333
run_tc = _TrialComponent.create(
336334
trial_component_name=trial_component_name,
337335
display_name=display_name,
338336
tags=tags,
339337
sagemaker_session=sagemaker_session,
340338
)
339+
except ClientError as ce:
340+
error_code = ce.response["Error"]["Code"]
341+
error_message = ce.response["Error"]["Message"]
342+
if not (error_code == "ValidationException" and "already exists" in error_message):
343+
raise ce
344+
# already exists
345+
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
346+
is_existed = True
341347
return run_tc, is_existed

‎src/sagemaker/session.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5448,8 +5448,9 @@ def _deployment_entity_exists(describe_fn):
54485448

54495449

54505450
def _create_resource(create_fn):
5451-
"""Call create function and while doing so accepts/passes the resource already exists exception.
5452-
Throws an exception if any exception other than resource already exists.
5451+
"""Call create function and accepts/pass when resource already exists.
5452+
5453+
This is a helper function to use an existing resource if found when creating.
54535454
54545455
Args:
54555456
create_fn: Create resource function.
@@ -5823,9 +5824,9 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58235824
job_desc, finished = callable_fn(job_desc)
58245825
except botocore.exceptions.ClientError as err:
58255826
# For initial 5 mins we accept/pass AccessDeniedException.
5826-
# The reason is to await tag propagation to avoid false AccessDenied claims for an access
5827-
# policy based on resource tags, The caveat here is for true AccessDenied cases the routine
5828-
# will fail after 5 mins
5827+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
5828+
# access policy based on resource tags, The caveat here is for true AccessDenied
5829+
# cases the routine will fail after 5 mins
58295830
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
58305831
LOGGER.warning(
58315832
"Received AccessDeniedException. This could mean the IAM role does not "
@@ -5834,8 +5835,6 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58345835
"continuing to wait for tag propagation.."
58355836
)
58365837
continue
5837-
else:
5838-
raise err
58395838
return job_desc
58405839

58415840

@@ -5861,8 +5860,6 @@ def _wait_until(callable_fn, poll=5):
58615860
"continuing to wait for tag propagation.."
58625861
)
58635862
continue
5864-
else:
5865-
raise err
58665863
return result
58675864

58685865

‎src/sagemaker/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,17 @@ def retries(
604604
)
605605

606606

607-
def retry_with_backoff(callable_func, num_attempts=8):
607+
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
608608
"""Retry with backoff until maximum attempts are reached
609609
610610
Args:
611611
callable_func (callable): The callable function to retry.
612-
num_attempts (int): The maximum number of attempts to retry.
612+
num_attempts (int): The maximum number of attempts to retry.(Default: 8)
613+
botocore_client_error_code (str): The specific Botocore ClientError exception error code
614+
on which to retry on.
615+
If provided other exceptions will be raised directly w/o retry.
616+
If not provided, retry on any exception.
617+
(Default: None)
613618
"""
614619
if num_attempts < 1:
615620
raise ValueError(
@@ -619,7 +624,15 @@ def retry_with_backoff(callable_func, num_attempts=8):
619624
try:
620625
return callable_func()
621626
except Exception as ex: # pylint: disable=broad-except
622-
if i == num_attempts - 1:
627+
if not botocore_client_error_code or (
628+
botocore_client_error_code
629+
and isinstance(ex) == botocore.exceptions.ClientError
630+
and ex.response["Error"]["Code"]
631+
== botocore_client_error_code # pylint: disable=no-member
632+
):
633+
if i == num_attempts - 1:
634+
raise ex
635+
else:
623636
raise ex
624637
logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
625638
time.sleep(2**i)

‎src/sagemaker/workflow/pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker import s3
2727
from sagemaker._studio import _append_project_tags
2828
from sagemaker.session import Session
29+
from sagemaker.utils import retry_with_backoff
2930
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
3031
from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
3132
from sagemaker.workflow.entities import (
@@ -306,7 +307,12 @@ def start(
306307
update_args(kwargs, PipelineParameters=parameters)
307308
return self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
308309
update_args(kwargs, PipelineParameters=format_start_parameters(parameters))
309-
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
310+
311+
# retry on AccessDeniedException to cover case of tag propagation delay
312+
response = retry_with_backoff(
313+
lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs),
314+
botocore_client_error_code="AccessDeniedException",
315+
)
310316
return _PipelineExecution(
311317
arn=response["PipelineExecutionArn"],
312318
sagemaker_session=self.sagemaker_session,

‎tests/unit/sagemaker/experiments/test_experiment.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import botocore.exceptions
1516
import pytest
1617
import unittest.mock
1718
import datetime
@@ -78,37 +79,48 @@ def test_delete(sagemaker_session):
7879

7980

8081
@patch("sagemaker.experiments.experiment._Experiment.load")
81-
def test_load_or_create_when_exist(mock_load, sagemaker_session):
82+
@patch("sagemaker.experiments.experiment._Experiment.create")
83+
def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
8284
exp_name = "exp_name"
85+
exists_error = botocore.exceptions.ClientError(
86+
error_response={
87+
"Error": {
88+
"Code": "ValidationException",
89+
"Message": "Experiment with name (experiment-xyz) already exists.",
90+
}
91+
},
92+
operation_name="foo",
93+
)
94+
mock_create.side_effect = exists_error
8395
experiment._Experiment._load_or_create(
8496
experiment_name=exp_name, sagemaker_session=sagemaker_session
8597
)
98+
mock_create.assert_called_once_with(
99+
experiment_name=exp_name,
100+
display_name=None,
101+
description=None,
102+
tags=None,
103+
sagemaker_session=sagemaker_session,
104+
)
86105
mock_load.assert_called_once_with(exp_name, sagemaker_session)
87106

88107

89108
@patch("sagemaker.experiments.experiment._Experiment.load")
90109
@patch("sagemaker.experiments.experiment._Experiment.create")
91110
def test_load_or_create_when_not_exist(mock_create, mock_load):
92111
sagemaker_session = Session()
93-
client = sagemaker_session.sagemaker_client
94112
exp_name = "exp_name"
95-
not_found_err = client.exceptions.ResourceNotFound(
96-
error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
97-
operation_name="foo",
98-
)
99-
mock_load.side_effect = not_found_err
100-
101113
experiment._Experiment._load_or_create(
102114
experiment_name=exp_name, sagemaker_session=sagemaker_session
103115
)
104-
105116
mock_create.assert_called_once_with(
106117
experiment_name=exp_name,
107118
display_name=None,
108119
description=None,
109120
tags=None,
110121
sagemaker_session=sagemaker_session,
111122
)
123+
mock_load.assert_not_called()
112124

113125

114126
def test_list_trials_empty(sagemaker_session):

‎tests/unit/sagemaker/experiments/test_trial.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import botocore
1516
import pytest
1617

1718
import datetime
@@ -133,11 +134,21 @@ def test_remove_trial_component(sagemaker_session):
133134

134135

135136
@patch("sagemaker.experiments.trial._Trial.load")
136-
def test_load_or_create_when_exist(mock_load):
137+
@patch("sagemaker.experiments.trial._Trial.create")
138+
def test_load_or_create_when_exist(mock_create, mock_load):
137139
sagemaker_session = Session()
138140
trial_name = "trial_name"
139141
exp_name = "exp_name"
140-
142+
exists_error = botocore.exceptions.ClientError(
143+
error_response={
144+
"Error": {
145+
"Code": "ValidationException",
146+
"Message": "Experiment with name (experiment-xyz) already exists.",
147+
}
148+
},
149+
operation_name="foo",
150+
)
151+
mock_create.side_effect = exists_error
141152
# The trial exists and experiment matches
142153
mock_load.return_value = _Trial(
143154
trial_name=trial_name,
@@ -147,6 +158,13 @@ def test_load_or_create_when_exist(mock_load):
147158
_Trial._load_or_create(
148159
trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
149160
)
161+
mock_create.assert_called_once_with(
162+
trial_name=trial_name,
163+
experiment_name=exp_name,
164+
display_name=None,
165+
tags=None,
166+
sagemaker_session=sagemaker_session,
167+
)
150168
mock_load.assert_called_once_with(trial_name, sagemaker_session)
151169

152170
# The trial exists but experiment does not match
@@ -168,14 +186,8 @@ def test_load_or_create_when_exist(mock_load):
168186
@patch("sagemaker.experiments.trial._Trial.create")
169187
def test_load_or_create_when_not_exist(mock_create, mock_load):
170188
sagemaker_session = Session()
171-
client = sagemaker_session.sagemaker_client
172189
trial_name = "trial_name"
173190
exp_name = "exp_name"
174-
not_found_err = client.exceptions.ResourceNotFound(
175-
error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
176-
operation_name="foo",
177-
)
178-
mock_load.side_effect = not_found_err
179191

180192
_Trial._load_or_create(
181193
trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
@@ -188,6 +200,7 @@ def test_load_or_create_when_not_exist(mock_create, mock_load):
188200
tags=None,
189201
sagemaker_session=sagemaker_session,
190202
)
203+
mock_load.assert_not_called()
191204

192205

193206
def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj):

‎tests/unit/sagemaker/experiments/test_trial_component.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from unittest.mock import patch
1919

20+
import botocore
21+
2022
from sagemaker import Session
2123
from sagemaker.experiments import _api_types
2224
from sagemaker.experiments._api_types import (
@@ -300,11 +302,28 @@ def test_list_trial_components_call_args(sagemaker_session):
300302

301303

302304
@patch("sagemaker.experiments.trial_component._TrialComponent.load")
303-
def test_load_or_create_when_exist(mock_load, sagemaker_session):
305+
@patch("sagemaker.experiments.trial_component._TrialComponent.create")
306+
def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
304307
tc_name = "tc_name"
308+
exists_error = botocore.exceptions.ClientError(
309+
error_response={
310+
"Error": {
311+
"Code": "ValidationException",
312+
"Message": "Experiment with name (experiment-xyz) already exists.",
313+
}
314+
},
315+
operation_name="foo",
316+
)
317+
mock_create.side_effect = exists_error
305318
_, is_existed = _TrialComponent._load_or_create(
306319
trial_component_name=tc_name, sagemaker_session=sagemaker_session
307320
)
321+
mock_create.assert_called_once_with(
322+
trial_component_name=tc_name,
323+
display_name=None,
324+
tags=None,
325+
sagemaker_session=sagemaker_session,
326+
)
308327
assert is_existed
309328
mock_load.assert_called_once_with(
310329
tc_name,
@@ -316,13 +335,7 @@ def test_load_or_create_when_exist(mock_load, sagemaker_session):
316335
@patch("sagemaker.experiments.trial_component._TrialComponent.create")
317336
def test_load_or_create_when_not_exist(mock_create, mock_load):
318337
sagemaker_session = Session()
319-
client = sagemaker_session.sagemaker_client
320338
tc_name = "tc_name"
321-
not_found_err = client.exceptions.ResourceNotFound(
322-
error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
323-
operation_name="foo",
324-
)
325-
mock_load.side_effect = not_found_err
326339

327340
_, is_existed = _TrialComponent._load_or_create(
328341
trial_component_name=tc_name, sagemaker_session=sagemaker_session
@@ -335,6 +348,7 @@ def test_load_or_create_when_not_exist(mock_create, mock_load):
335348
tags=None,
336349
sagemaker_session=sagemaker_session,
337350
)
351+
mock_load.assert_not_called()
338352

339353

340354
def test_search(sagemaker_session):

‎tests/unit/test_utils.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,15 +795,17 @@ def test_to_string():
795795
}
796796

797797

798-
def test_start_waiting(capfd):
798+
@patch("time.sleep", return_value=None)
799+
def test_start_waiting(patched_sleep, capfd):
799800
waiting_time = 1
800801
sagemaker.utils._start_waiting(waiting_time)
801802
out, _ = capfd.readouterr()
802803

803804
assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out
804805

805806

806-
def test_retry_with_backoff():
807+
@patch("time.sleep", return_value=None)
808+
def test_retry_with_backoff(patched_sleep):
807809
callable_func = Mock()
808810

809811
# Invalid input
@@ -824,6 +826,25 @@ def test_retry_with_backoff():
824826
callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val]
825827
assert retry_with_backoff(callable_func, 2) == func_return_val
826828

829+
# when retry on specific error, fail for other error on 1st try
830+
func_return_val = "Test Return"
831+
response = {"Error": {"Code": "ValidationException", "Message": "Could not find entity."}}
832+
error = botocore.exceptions.ClientError(error_response=response, operation_name="foo")
833+
callable_func.side_effect = [error, func_return_val]
834+
with pytest.raises(botocore.exceptions.ClientError) as run_err:
835+
retry_with_backoff(callable_func, 2, botocore_client_error_code="AccessDeniedException")
836+
assert "ValidationException" in str(run_err)
837+
838+
# when retry on specific error, One retry passes
839+
func_return_val = "Test Return"
840+
response = {"Error": {"Code": "AccessDeniedException", "Message": "Access denied."}}
841+
error = botocore.exceptions.ClientError(error_response=response, operation_name="foo")
842+
callable_func.side_effect = [error, func_return_val]
843+
assert (
844+
retry_with_backoff(callable_func, 2, botocore_client_error_code="AccessDeniedException")
845+
== func_return_val
846+
)
847+
827848
# No retry
828849
callable_func.side_effect = None
829850
callable_func.return_value = func_return_val

0 commit comments

Comments
 (0)
Please sign in to comment.