Skip to content

Commit 67e5e1b

Browse files
mufaddal-rohawalaNamrata Madan
authored and
Namrata Madan
committed
fix: tag permission issue - remove describe before create (aws#3662)
* fix: tag permission issue - remove describe before create * Revert "fix: Add retry in session.py to check if training is finished (aws#3285)" This reverts commit 5bc3ccf. * fix: job wait logic to wait for tag propagation * fix: add tag fixes for pipelines and experiments
1 parent 4bd3eac commit 67e5e1b

18 files changed

+588
-198
lines changed

src/sagemaker/experiments/experiment.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments.trial import _Trial
2022
from sagemaker.experiments.trial_component import _TrialComponent
@@ -154,17 +156,21 @@ def _load_or_create(
154156
Returns:
155157
experiments.experiment._Experiment: A SageMaker `_Experiment` object
156158
"""
157-
sagemaker_client = sagemaker_session.sagemaker_client
158159
try:
159-
experiment = _Experiment.load(experiment_name, sagemaker_session)
160-
except sagemaker_client.exceptions.ResourceNotFound:
161160
experiment = _Experiment.create(
162161
experiment_name=experiment_name,
163162
display_name=display_name,
164163
description=description,
165164
tags=tags,
166165
sagemaker_session=sagemaker_session,
167166
)
167+
except ClientError as ce:
168+
error_code = ce.response["Error"]["Code"]
169+
error_message = ce.response["Error"]["Message"]
170+
if not (error_code == "ValidationException" and "already exists" in error_message):
171+
raise ce
172+
# already exists
173+
experiment = _Experiment.load(experiment_name, sagemaker_session)
168174
return experiment
169175

170176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

src/sagemaker/experiments/trial.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Contains the Trial class."""
1414
from __future__ import absolute_import
1515

16+
from botocore.exceptions import ClientError
17+
1618
from sagemaker.apiutils import _base_types
1719
from sagemaker.experiments import _api_types
1820
from sagemaker.experiments.trial_component import _TrialComponent
@@ -268,8 +270,20 @@ def _load_or_create(
268270
Returns:
269271
experiments.trial._Trial: A SageMaker `_Trial` object
270272
"""
271-
sagemaker_client = sagemaker_session.sagemaker_client
272273
try:
274+
trial = _Trial.create(
275+
experiment_name=experiment_name,
276+
trial_name=trial_name,
277+
display_name=display_name,
278+
tags=tags,
279+
sagemaker_session=sagemaker_session,
280+
)
281+
except ClientError as ce:
282+
error_code = ce.response["Error"]["Code"]
283+
error_message = ce.response["Error"]["Message"]
284+
if not (error_code == "ValidationException" and "already exists" in error_message):
285+
raise ce
286+
# already exists
273287
trial = _Trial.load(trial_name, sagemaker_session)
274288
if trial.experiment_name != experiment_name: # pylint: disable=no-member
275289
raise ValueError(
@@ -278,12 +292,4 @@ def _load_or_create(
278292
trial.experiment_name # pylint: disable=no-member
279293
)
280294
)
281-
except sagemaker_client.exceptions.ResourceNotFound:
282-
trial = _Trial.create(
283-
experiment_name=experiment_name,
284-
trial_name=trial_name,
285-
display_name=display_name,
286-
tags=tags,
287-
sagemaker_session=sagemaker_session,
288-
)
289295
return trial

src/sagemaker/experiments/trial_component.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments import _api_types
2022
from sagemaker.experiments._api_types import TrialComponentSearchResult
@@ -326,16 +328,20 @@ def _load_or_create(
326328
experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
327329
bool: A boolean variable indicating whether the trail component already exists
328330
"""
329-
sagemaker_client = sagemaker_session.sagemaker_client
330331
is_existed = False
331332
try:
332-
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
333-
is_existed = True
334-
except sagemaker_client.exceptions.ResourceNotFound:
335333
run_tc = _TrialComponent.create(
336334
trial_component_name=trial_component_name,
337335
display_name=display_name,
338336
tags=tags,
339337
sagemaker_session=sagemaker_session,
340338
)
339+
except ClientError as ce:
340+
error_code = ce.response["Error"]["Code"]
341+
error_message = ce.response["Error"]["Message"]
342+
if not (error_code == "ValidationException" and "already exists" in error_message):
343+
raise ce
344+
# already exists
345+
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
346+
is_existed = True
341347
return run_tc, is_existed

src/sagemaker/session.py

+105-55
Original file line numberDiff line numberDiff line change
@@ -3223,14 +3223,11 @@ def create_model_package_from_containers(
32233223

32243224
def submit(request):
32253225
if model_package_group_name is not None:
3226-
try:
3227-
self.sagemaker_client.describe_model_package_group(
3228-
ModelPackageGroupName=request["ModelPackageGroupName"]
3229-
)
3230-
except ClientError:
3231-
self.sagemaker_client.create_model_package_group(
3226+
_create_resource(
3227+
lambda: self.sagemaker_client.create_model_package_group(
32323228
ModelPackageGroupName=request["ModelPackageGroupName"]
32333229
)
3230+
)
32343231
return self.sagemaker_client.create_model_package(**request)
32353232

32363233
return self._intercept_create_request(
@@ -3918,42 +3915,40 @@ def endpoint_from_model_data(
39183915
name = name or name_from_image(image_uri)
39193916
model_vpc_config = vpc_utils.sanitize(model_vpc_config)
39203917

3921-
if _deployment_entity_exists(
3922-
lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)
3923-
):
3924-
raise ValueError(
3925-
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3926-
)
3918+
primary_container = container_def(
3919+
image_uri=image_uri,
3920+
model_data_url=model_s3_location,
3921+
env=model_environment_vars,
3922+
)
39273923

3928-
if not _deployment_entity_exists(
3929-
lambda: self.sagemaker_client.describe_model(ModelName=name)
3930-
):
3931-
primary_container = container_def(
3932-
image_uri=image_uri,
3933-
model_data_url=model_s3_location,
3934-
env=model_environment_vars,
3935-
)
3936-
self.create_model(
3937-
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3938-
)
3924+
self.create_model(
3925+
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3926+
)
39393927

39403928
data_capture_config_dict = None
39413929
if data_capture_config is not None:
39423930
data_capture_config_dict = data_capture_config._to_request_dict()
39433931

3944-
if not _deployment_entity_exists(
3945-
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
3946-
):
3947-
self.create_endpoint_config(
3932+
_create_resource(
3933+
lambda: self.create_endpoint_config(
39483934
name=name,
39493935
model_name=name,
39503936
initial_instance_count=initial_instance_count,
39513937
instance_type=instance_type,
39523938
accelerator_type=accelerator_type,
39533939
data_capture_config_dict=data_capture_config_dict,
39543940
)
3941+
)
3942+
3943+
# to make change backwards compatible
3944+
response = _create_resource(
3945+
lambda: self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
3946+
)
3947+
if not response:
3948+
raise ValueError(
3949+
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3950+
)
39553951

3956-
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
39573952
return name
39583953

39593954
def endpoint_from_production_variants(
@@ -5452,34 +5447,54 @@ def _deployment_entity_exists(describe_fn):
54525447
return False
54535448

54545449

5450+
def _create_resource(create_fn):
5451+
"""Call create function and accepts/pass when resource already exists.
5452+
5453+
This is a helper function to use an existing resource if found when creating.
5454+
5455+
Args:
5456+
create_fn: Create resource function.
5457+
5458+
Returns:
5459+
(bool): True if new resource was created, False if resource already exists.
5460+
"""
5461+
try:
5462+
create_fn()
5463+
# create function succeeded, resource does not exist already
5464+
return True
5465+
except ClientError as ce:
5466+
error_code = ce.response["Error"]["Code"]
5467+
error_message = ce.response["Error"]["Message"]
5468+
already_exists_exceptions = ["ValidationException", "ResourceInUse"]
5469+
already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
5470+
if not (
5471+
error_code in already_exists_exceptions
5472+
and any(p in error_message for p in already_exists_msg_patterns)
5473+
):
5474+
raise ce
5475+
# no new resource created as resource already exists
5476+
return False
5477+
5478+
54555479
def _train_done(sagemaker_client, job_name, last_desc):
54565480
"""Placeholder docstring"""
54575481
in_progress_statuses = ["InProgress", "Created"]
54585482

5459-
for _ in retries(
5460-
max_retry_count=10, # 10*30 = 5min
5461-
exception_message_prefix="Waiting for schedule to leave 'Pending' status",
5462-
seconds_to_sleep=30,
5463-
):
5464-
try:
5465-
desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
5466-
status = desc["TrainingJobStatus"]
5483+
desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
5484+
status = desc["TrainingJobStatus"]
54675485

5468-
if secondary_training_status_changed(desc, last_desc):
5469-
print()
5470-
print(secondary_training_status_message(desc, last_desc), end="")
5471-
else:
5472-
print(".", end="")
5473-
sys.stdout.flush()
5486+
if secondary_training_status_changed(desc, last_desc):
5487+
print()
5488+
print(secondary_training_status_message(desc, last_desc), end="")
5489+
else:
5490+
print(".", end="")
5491+
sys.stdout.flush()
54745492

5475-
if status in in_progress_statuses:
5476-
return desc, False
5493+
if status in in_progress_statuses:
5494+
return desc, False
54775495

5478-
print()
5479-
return desc, True
5480-
except botocore.exceptions.ClientError as err:
5481-
if err.response["Error"]["Code"] == "AccessDeniedException":
5482-
pass
5496+
print()
5497+
return desc, True
54835498

54845499

54855500
def _processing_job_status(sagemaker_client, job_name):
@@ -5799,19 +5814,54 @@ def _deploy_done(sagemaker_client, endpoint_name):
57995814

58005815
def _wait_until_training_done(callable_fn, desc, poll=5):
58015816
"""Placeholder docstring"""
5802-
job_desc, finished = callable_fn(desc)
5817+
elapsed_time = 0
5818+
finished = None
5819+
job_desc = desc
58035820
while not finished:
5804-
time.sleep(poll)
5805-
job_desc, finished = callable_fn(job_desc)
5821+
try:
5822+
elapsed_time += poll
5823+
time.sleep(poll)
5824+
job_desc, finished = callable_fn(job_desc)
5825+
except botocore.exceptions.ClientError as err:
5826+
# For initial 5 mins we accept/pass AccessDeniedException.
5827+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
5828+
# access policy based on resource tags, The caveat here is for true AccessDenied
5829+
# cases the routine will fail after 5 mins
5830+
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
5831+
LOGGER.warning(
5832+
"Received AccessDeniedException. This could mean the IAM role does not "
5833+
"have the resource permissions, in which case please add resource access "
5834+
"and retry. For cases where the role has tag based resource policy, "
5835+
"continuing to wait for tag propagation.."
5836+
)
5837+
continue
5838+
raise err
58065839
return job_desc
58075840

58085841

58095842
def _wait_until(callable_fn, poll=5):
58105843
"""Placeholder docstring"""
5811-
result = callable_fn()
5844+
elapsed_time = 0
5845+
result = None
58125846
while result is None:
5813-
time.sleep(poll)
5814-
result = callable_fn()
5847+
try:
5848+
elapsed_time += poll
5849+
time.sleep(poll)
5850+
result = callable_fn()
5851+
except botocore.exceptions.ClientError as err:
5852+
# For initial 5 mins we accept/pass AccessDeniedException.
5853+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
5854+
# access policy based on resource tags, The caveat here is for true AccessDenied
5855+
# cases the routine will fail after 5 mins
5856+
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
5857+
LOGGER.warning(
5858+
"Received AccessDeniedException. This could mean the IAM role does not "
5859+
"have the resource permissions, in which case please add resource access "
5860+
"and retry. For cases where the role has tag based resource policy, "
5861+
"continuing to wait for tag propagation.."
5862+
)
5863+
continue
5864+
raise err
58155865
return result
58165866

58175867

src/sagemaker/utils.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,17 @@ def retries(
604604
)
605605

606606

607-
def retry_with_backoff(callable_func, num_attempts=8):
607+
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
608608
"""Retry with backoff until maximum attempts are reached
609609
610610
Args:
611611
callable_func (callable): The callable function to retry.
612-
num_attempts (int): The maximum number of attempts to retry.
612+
num_attempts (int): The maximum number of attempts to retry.(Default: 8)
613+
botocore_client_error_code (str): The specific Botocore ClientError exception error code
614+
on which to retry on.
615+
If provided other exceptions will be raised directly w/o retry.
616+
If not provided, retry on any exception.
617+
(Default: None)
613618
"""
614619
if num_attempts < 1:
615620
raise ValueError(
@@ -619,7 +624,15 @@ def retry_with_backoff(callable_func, num_attempts=8):
619624
try:
620625
return callable_func()
621626
except Exception as ex: # pylint: disable=broad-except
622-
if i == num_attempts - 1:
627+
if not botocore_client_error_code or (
628+
botocore_client_error_code
629+
and isinstance(ex, botocore.exceptions.ClientError)
630+
and ex.response["Error"]["Code"] # pylint: disable=no-member
631+
== botocore_client_error_code
632+
):
633+
if i == num_attempts - 1:
634+
raise ex
635+
else:
623636
raise ex
624637
logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
625638
time.sleep(2**i)

0 commit comments

Comments
 (0)