Skip to content

Commit 607f85a

Browse files
fix: job wait logic to wait for tag propagation
1 parent c98b231 commit 607f85a

File tree

6 files changed

+182
-38
lines changed

6 files changed

+182
-38
lines changed

src/sagemaker/session.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
secondary_training_status_changed,
4343
secondary_training_status_message,
4444
sts_regional_endpoint,
45+
retries,
4546
)
4647
from sagemaker import exceptions
4748
from sagemaker.session_settings import SessionSettings
@@ -5812,19 +5813,56 @@ def _deploy_done(sagemaker_client, endpoint_name):
58125813

58135814
def _wait_until_training_done(callable_fn, desc, poll=5):
    """Poll ``callable_fn`` until it reports that the job has finished.

    Args:
        callable_fn (callable): Called with the most recent job description; must return a
            ``(job_description, finished)`` pair.
        desc: Job description handed to the first ``callable_fn`` call.
        poll (int): Seconds to sleep between two consecutive calls (default: 5).

    Returns:
        The job description returned by the call that reported ``finished`` as truthy.

    Raises:
        botocore.exceptions.ClientError: If ``callable_fn`` raises it with an error code other
            than ``AccessDeniedException``, or with ``AccessDeniedException`` once the grace
            period for tag propagation has been used up.
    """
    # Resource tags can take a while to propagate, so a role whose access policy is based on
    # resource tags may receive false AccessDeniedException responses at first. Those are
    # tolerated for the first 5 minutes of waiting; the caveat is that a true AccessDenied
    # only surfaces after that period. The period is measured in accumulated poll intervals.
    access_denied_grace_seconds = 300
    waited_seconds = 0
    job_desc = desc
    while True:
        try:
            # Check before sleeping so an already finished job returns without delay.
            job_desc, finished = callable_fn(job_desc)
        except botocore.exceptions.ClientError as err:
            error_code = err.response.get("Error", {}).get("Code")
            grace_expired = waited_seconds >= access_denied_grace_seconds
            if error_code != "AccessDeniedException" or grace_expired:
                # Bare raise keeps the original traceback intact.
                raise
            LOGGER.warning(
                "Received AccessDeniedException. This could mean the IAM role does not "
                "have the resource permissions, in which case please add resource access "
                "and retry. For cases where the role has tag based resource policy, "
                "continuing to wait for tag propagation.."
            )
        else:
            if finished:
                return job_desc
        time.sleep(poll)
        waited_seconds += poll
58205840

58215841

58225842
def _wait_until(callable_fn, poll=5):
    """Poll ``callable_fn`` until it returns something other than ``None``.

    Args:
        callable_fn (callable): Zero-argument callable; ``None`` means "not ready yet".
        poll (int): Seconds to sleep between two consecutive calls (default: 5).

    Returns:
        The first non-``None`` value returned by ``callable_fn``.

    Raises:
        botocore.exceptions.ClientError: If ``callable_fn`` raises it with an error code other
            than ``AccessDeniedException``, or with ``AccessDeniedException`` once the grace
            period for tag propagation has been used up.
    """
    # Resource tags can take a while to propagate, so a role whose access policy is based on
    # resource tags may receive false AccessDeniedException responses at first. Those are
    # tolerated for the first 5 minutes of waiting; the caveat is that a true AccessDenied
    # only surfaces after that period. The period is measured in accumulated poll intervals.
    access_denied_grace_seconds = 300
    waited_seconds = 0
    while True:
        try:
            # Check before sleeping so an already available result returns without delay.
            result = callable_fn()
        except botocore.exceptions.ClientError as err:
            error_code = err.response.get("Error", {}).get("Code")
            grace_expired = waited_seconds >= access_denied_grace_seconds
            if error_code != "AccessDeniedException" or grace_expired:
                # Bare raise keeps the original traceback intact.
                raise
            LOGGER.warning(
                "Received AccessDeniedException. This could mean the IAM role does not "
                "have the resource permissions, in which case please add resource access "
                "and retry. For cases where the role has tag based resource policy, "
                "continuing to wait for tag propagation.."
            )
        else:
            if result is not None:
                return result
        time.sleep(poll)
        waited_seconds += poll
58295867

58305868

src/sagemaker/workflow/pipeline.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,7 @@ def upsert(
246246
except ClientError as ce:
247247
error_code = ce.response["Error"]["Code"]
248248
error_message = ce.response["Error"]["Message"]
249-
if not (
250-
error_code == "ValidationException"
251-
and "already exists" in error_message
252-
):
249+
if not (error_code == "ValidationException" and "already exists" in error_message):
253250
raise ce
254251
# already exists
255252
response = self.update(role_arn, description)

tests/unit/common.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
from botocore.exceptions import ClientError
17+
18+
19+
def _raise_unexpected_client_error(**kwargs):
    """Raise a ``ValidationException`` ClientError with a name-validation message.

    Any keyword arguments are accepted and ignored.
    """
    error = {"Code": "ValidationException", "Message": "Name does not satisfy expression."}
    raise ClientError(error_response={"Error": error}, operation_name="foo")
24+
25+
26+
def _raise_does_not_exist_client_error(**kwargs):
    """Raise a ``ValidationException`` ClientError saying the entity could not be found.

    Any keyword arguments are accepted and ignored.
    """
    error = {"Code": "ValidationException", "Message": "Could not find entity."}
    raise ClientError(error_response={"Error": error}, operation_name="foo")
29+
30+
31+
def _raise_does_already_exists_client_error(**kwargs):
    """Raise a ``ValidationException`` ClientError saying the resource already exists.

    Any keyword arguments are accepted and ignored.
    """
    error = {"Code": "ValidationException", "Message": "Resource already exists."}
    raise ClientError(error_response={"Error": error}, operation_name="foo")
34+
35+
36+
def _raise_access_denied_client_error(**kwargs):
    """Raise an ``AccessDeniedException`` ClientError.

    Any keyword arguments are accepted and ignored.
    """
    error = {"Code": "AccessDeniedException", "Message": "Could not access entity."}
    raise ClientError(error_response={"Error": error}, operation_name="foo")

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from botocore.exceptions import ClientError
3737

3838

39-
4039
@pytest.fixture
4140
def role_arn():
4241
return "arn:role"
@@ -207,8 +206,10 @@ def _raise_does_already_exists_client_error(**kwargs):
207206
pipeline.upsert(role_arn=role_arn, tags=tags)
208207

209208
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
210-
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
211-
Tags=tags
209+
PipelineName="MyPipeline",
210+
RoleArn=role_arn,
211+
PipelineDefinition=pipeline.definition(),
212+
Tags=tags,
212213
)
213214

214215
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with(
@@ -223,6 +224,7 @@ def _raise_does_already_exists_client_error(**kwargs):
223224
ResourceArn="mock_pipeline_arn", Tags=tags
224225
)
225226

227+
226228
def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):
227229

228230
# case 2: unexpected failure on create
@@ -258,16 +260,17 @@ def _raise_unexpected_client_error(**kwargs):
258260
with pytest.raises(ClientError):
259261
pipeline.upsert(role_arn=role_arn, tags=tags)
260262

261-
262-
263263
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
264-
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
265-
Tags=tags
264+
PipelineName="MyPipeline",
265+
RoleArn=role_arn,
266+
PipelineDefinition=pipeline.definition(),
267+
Tags=tags,
266268
)
267269
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()
268270
sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called()
269271
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
270272

273+
271274
def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
272275

273276
# case 3: resource does not exist
@@ -295,11 +298,13 @@ def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn)
295298
try:
296299
pipeline.upsert(role_arn=role_arn, tags=tags)
297300
except ClientError:
298-
assert False, f"Unexpected ClientError raised"
301+
assert False, "Unexpected ClientError raised"
299302

300303
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
301-
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
302-
Tags=tags
304+
PipelineName="MyPipeline",
305+
RoleArn=role_arn,
306+
PipelineDefinition=pipeline.definition(),
307+
Tags=tags,
303308
)
304309

305310
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()

tests/unit/test_endpoint_from_model_data.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
from mock import MagicMock, Mock
1818
from mock import patch
1919

20+
from .common import (
21+
_raise_unexpected_client_error,
22+
_raise_does_already_exists_client_error,
23+
_raise_does_not_exist_client_error,
24+
)
2025
import sagemaker
2126

2227
ENDPOINT_NAME = "myendpoint"
@@ -126,7 +131,7 @@ def test_model_and_endpoint_config_exist(name_from_image_mock, sagemaker_session
126131
wait=False,
127132
)
128133
except ClientError:
129-
assert False, f"Unexpected ClientError raised for resource already exists scenario"
134+
assert False, "Unexpected ClientError raised for resource already exists scenario"
130135

131136
sagemaker_session.create_model.assert_called_once_with(
132137
name=NAME_FROM_IMAGE,
@@ -208,20 +213,3 @@ def test_entity_doesnt_exist():
208213
def test_describe_failure():
209214
with pytest.raises(ClientError):
210215
sagemaker.session._deployment_entity_exists(_raise_unexpected_client_error)
211-
212-
213-
def _raise_unexpected_client_error(**kwargs):
214-
response = {
215-
"Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."}
216-
}
217-
raise ClientError(error_response=response, operation_name="foo")
218-
219-
220-
def _raise_does_not_exist_client_error(**kwargs):
221-
response = {"Error": {"Code": "ValidationException", "Message": "Could not find entity."}}
222-
raise ClientError(error_response=response, operation_name="foo")
223-
224-
225-
def _raise_does_already_exists_client_error(**kwargs):
226-
response = {"Error": {"Code": "ValidationException", "Message": "Resource already exists."}}
227-
raise ClientError(error_response=response, operation_name="foo")

tests/unit/test_session.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
from botocore.exceptions import ClientError
2323
from mock import ANY, MagicMock, Mock, patch, call, mock_open
2424

25+
from .common import _raise_unexpected_client_error
2526
import sagemaker
2627
from sagemaker import TrainingInput, Session, get_execution_role, exceptions
2728
from sagemaker.async_inference import AsyncInferenceConfig
2829
from sagemaker.session import (
2930
_tuning_job_status,
3031
_transform_job_status,
3132
_train_done,
33+
_wait_until,
34+
_wait_until_training_done,
3235
NOTEBOOK_METADATA_FILE,
3336
)
3437
from sagemaker.tuner import WarmStartConfig, WarmStartTypes
@@ -2342,6 +2345,81 @@ def test_train_done_in_progress(sagemaker_session):
23422345
assert training_finished is False
23432346

23442347

2348+
@patch("time.sleep", return_value=None)
def test_wait_until_training_done_raises_other_exception(patched_sleep):
    """A ClientError other than AccessDeniedException propagates after a single call."""
    response = {"Error": {"Code": "ValidationException", "Message": "Could not access entity."}}
    mock_func = Mock(
        name="describe_training_job",
        side_effect=ClientError(error_response=response, operation_name="foo"),
    )
    desc = "dummy"
    with pytest.raises(ClientError) as error:
        _wait_until_training_done(mock_func, desc)

    mock_func.assert_called_once()
    # Inspect the raised exception itself, not the string form of the ExceptionInfo wrapper.
    assert error.value.response["Error"]["Code"] == "ValidationException"
2361+
2362+
2363+
@patch("time.sleep", return_value=None)
def test_wait_until_training_done_tag_propagation(patched_sleep):
    """AccessDeniedException failures are retried until the callable finally succeeds."""
    access_denied = ClientError(
        error_response={
            "Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}
        },
        operation_name="foo",
    )
    # Three denied attempts followed by one successful (description, finished) pair.
    outcomes = [access_denied, access_denied, access_denied, ("result", "result")]
    mock_func = Mock(name="describe_training_job", side_effect=outcomes)

    result = _wait_until_training_done(mock_func, "dummy")

    assert result == "result"
    assert mock_func.call_count == 4
2373+
2374+
2375+
@patch("time.sleep", return_value=None)
def test_wait_until_training_done_fail_access_denied_after_5_mins(patched_sleep):
    """A persistent AccessDeniedException is raised once the 5 minute grace period is over."""
    response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}}
    side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 70
    mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter)
    desc = "dummy"
    with pytest.raises(ClientError) as error:
        _wait_until_training_done(mock_func, desc)

    # 300 (grace period) / 5 (default poll delay) = 60 tolerated calls; the 61st call raises.
    assert mock_func.call_count == 61
    # Inspect the raised exception itself, not the string form of the ExceptionInfo wrapper.
    assert error.value.response["Error"]["Code"] == "AccessDeniedException"
2387+
2388+
2389+
@patch("time.sleep", return_value=None)
def test_wait_until_raises_other_exception(patched_sleep):
    """A ClientError other than AccessDeniedException propagates after a single call."""
    mock_func = Mock(name="describe_training_job", side_effect=_raise_unexpected_client_error)
    with pytest.raises(ClientError) as error:
        _wait_until(mock_func)

    mock_func.assert_called_once()
    # Inspect the raised exception itself, not the string form of the ExceptionInfo wrapper.
    assert error.value.response["Error"]["Code"] == "ValidationException"
2397+
2398+
2399+
@patch("time.sleep", return_value=None)
def test_wait_until_tag_propagation(patched_sleep):
    """AccessDeniedException failures are retried until the callable finally succeeds."""
    access_denied = ClientError(
        error_response={
            "Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}
        },
        operation_name="foo",
    )
    # Three denied attempts followed by one successful result.
    outcomes = [access_denied, access_denied, access_denied, "result"]
    mock_func = Mock(name="describe_training_job", side_effect=outcomes)

    result = _wait_until(mock_func)

    assert result == "result"
    assert mock_func.call_count == 4
2408+
2409+
2410+
@patch("time.sleep", return_value=None)
def test_wait_until_fail_access_denied_after_5_mins(patched_sleep):
    """A persistent AccessDeniedException is raised once the 5 minute grace period is over."""
    response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}}
    side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 70
    mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter)
    with pytest.raises(ClientError) as error:
        _wait_until(mock_func)

    # 300 (grace period) / 5 (default poll delay) = 60 tolerated calls; the 61st call raises.
    assert mock_func.call_count == 61
    # Inspect the raised exception itself, not the string form of the ExceptionInfo wrapper.
    assert error.value.response["Error"]["Code"] == "AccessDeniedException"
2421+
2422+
23452423
DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS = {
23462424
"AutoMLJobName": JOB_NAME,
23472425
"InputDataConfig": [

0 commit comments

Comments
 (0)