Skip to content

Commit e7ed90d

Browse files
fix: tag permission issue - remove describe before create
1 parent 48141db commit e7ed90d

File tree

4 files changed

+250
-90
lines changed

4 files changed

+250
-90
lines changed

src/sagemaker/session.py

+50-27
Original file line numberDiff line numberDiff line change
@@ -3223,14 +3223,11 @@ def create_model_package_from_containers(
32233223

32243224
def submit(request):
32253225
if model_package_group_name is not None:
3226-
try:
3227-
self.sagemaker_client.describe_model_package_group(
3228-
ModelPackageGroupName=request["ModelPackageGroupName"]
3229-
)
3230-
except ClientError:
3226+
_create_resource(
32313227
self.sagemaker_client.create_model_package_group(
32323228
ModelPackageGroupName=request["ModelPackageGroupName"]
32333229
)
3230+
)
32343231
return self.sagemaker_client.create_model_package(**request)
32353232

32363233
return self._intercept_create_request(
@@ -3918,42 +3915,40 @@ def endpoint_from_model_data(
39183915
name = name or name_from_image(image_uri)
39193916
model_vpc_config = vpc_utils.sanitize(model_vpc_config)
39203917

3921-
if _deployment_entity_exists(
3922-
lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)
3923-
):
3924-
raise ValueError(
3925-
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3926-
)
3918+
primary_container = container_def(
3919+
image_uri=image_uri,
3920+
model_data_url=model_s3_location,
3921+
env=model_environment_vars,
3922+
)
39273923

3928-
if not _deployment_entity_exists(
3929-
lambda: self.sagemaker_client.describe_model(ModelName=name)
3930-
):
3931-
primary_container = container_def(
3932-
image_uri=image_uri,
3933-
model_data_url=model_s3_location,
3934-
env=model_environment_vars,
3935-
)
3936-
self.create_model(
3937-
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3938-
)
3924+
self.create_model(
3925+
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3926+
)
39393927

39403928
data_capture_config_dict = None
39413929
if data_capture_config is not None:
39423930
data_capture_config_dict = data_capture_config._to_request_dict()
39433931

3944-
if not _deployment_entity_exists(
3945-
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
3946-
):
3947-
self.create_endpoint_config(
3932+
_create_resource(
3933+
lambda: self.create_endpoint_config(
39483934
name=name,
39493935
model_name=name,
39503936
initial_instance_count=initial_instance_count,
39513937
instance_type=instance_type,
39523938
accelerator_type=accelerator_type,
39533939
data_capture_config_dict=data_capture_config_dict,
39543940
)
3941+
)
3942+
3943+
# to make change backwards compatible
3944+
response = _create_resource(
3945+
lambda: self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
3946+
)
3947+
if not response:
3948+
raise ValueError(
3949+
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3950+
)
39553951

3956-
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
39573952
return name
39583953

39593954
def endpoint_from_production_variants(
@@ -5452,6 +5447,34 @@ def _deployment_entity_exists(describe_fn):
54525447
return False
54535448

54545449

5450+
def _create_resource(create_fn):
5451+
"""Call create function and while doing so accepts/passes the resource already exists exception.
5452+
Throws an exception if any exception other than resource already exists.
5453+
5454+
Args:
5455+
create_fn: Create resource function.
5456+
5457+
Returns:
5458+
(bool): True if new resource was created, False if resource already exists.
5459+
"""
5460+
try:
5461+
create_fn()
5462+
# create function succeeded, resource does not exist already
5463+
return True
5464+
except ClientError as ce:
5465+
error_code = ce.response["Error"]["Code"]
5466+
error_message = ce.response["Error"]["Message"]
5467+
already_exists_exceptions = ["ValidationException", "ResourceInUse"]
5468+
already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
5469+
if not (
5470+
error_code in already_exists_exceptions
5471+
and any(p in error_message for p in already_exists_msg_patterns)
5472+
):
5473+
raise ce
5474+
# no new resource created as resource already exists
5475+
return False
5476+
5477+
54555478
def _train_done(sagemaker_client, job_name, last_desc):
54565479
"""Placeholder docstring"""
54575480
in_progress_statuses = ["InProgress", "Created"]

src/sagemaker/workflow/pipeline.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,19 @@ def upsert(
241241
Returns:
242242
response dict from service
243243
"""
244-
exists = True
245244
try:
246-
self.describe()
247-
except ClientError as e:
248-
err = e.response.get("Error", {})
249-
if err.get("Code", None) == "ResourceNotFound":
250-
exists = False
251-
else:
252-
raise e
253-
254-
if not exists:
255245
response = self.create(role_arn, description, tags, parallelism_config)
256-
else:
246+
except ClientError as ce:
247+
error_code = ce.response["Error"]["Code"]
248+
error_message = ce.response["Error"]["Message"]
249+
if not (
250+
error_code == "ValidationException"
251+
and "already exists" in error_message
252+
):
253+
raise ce
254+
# already exists
257255
response = self.update(role_arn, description)
256+
# add new tags to existing resource
258257
if tags is not None:
259258
old_tags = self.sagemaker_session.sagemaker_client.list_tags(
260259
ResourceArn=response["PipelineArn"]

tests/unit/sagemaker/workflow/test_pipeline.py

+101-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from sagemaker.workflow.step_collections import StepCollection
3434
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep
3535
from sagemaker.local.local_session import LocalSession
36+
from botocore.exceptions import ClientError
37+
3638

3739

3840
@pytest.fixture
@@ -173,10 +175,17 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
173175
)
174176

175177

176-
def test_pipeline_upsert(sagemaker_session_mock, role_arn):
177-
sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = {
178-
"PipelineArn": "pipeline-arn"
179-
}
178+
def test_pipeline_upsert_resource_already_exists(sagemaker_session_mock, role_arn):
179+
180+
# case 1: resource already exists
181+
def _raise_does_already_exists_client_error(**kwargs):
182+
response = {"Error": {"Code": "ValidationException", "Message": "Resource already exists."}}
183+
raise ClientError(error_response=response, operation_name="create_pipeline")
184+
185+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
186+
name="create_pipeline", side_effect=_raise_does_already_exists_client_error
187+
)
188+
180189
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
181190
"PipelineArn": "pipeline-arn"
182191
}
@@ -197,9 +206,12 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
197206
]
198207
pipeline.upsert(role_arn=role_arn, tags=tags)
199208

200-
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_not_called()
209+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
210+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
211+
Tags=tags
212+
)
201213

202-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
214+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with(
203215
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
204216
)
205217
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(
@@ -211,6 +223,89 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
211223
ResourceArn="mock_pipeline_arn", Tags=tags
212224
)
213225

226+
def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):
227+
228+
# case 2: unexpected failure on create
229+
def _raise_unexpected_client_error(**kwargs):
230+
response = {
231+
"Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."}
232+
}
233+
raise ClientError(error_response=response, operation_name="foo")
234+
235+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
236+
name="create_pipeline", side_effect=_raise_unexpected_client_error
237+
)
238+
239+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
240+
"PipelineArn": "pipeline-arn"
241+
}
242+
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
243+
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
244+
}
245+
246+
tags = [
247+
{"Key": "foo", "Value": "abc"},
248+
{"Key": "bar", "Value": "xyz"},
249+
]
250+
251+
pipeline = Pipeline(
252+
name="MyPipeline",
253+
parameters=[],
254+
steps=[],
255+
sagemaker_session=sagemaker_session_mock,
256+
)
257+
258+
with pytest.raises(ClientError):
259+
pipeline.upsert(role_arn=role_arn, tags=tags)
260+
261+
262+
263+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
264+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
265+
Tags=tags
266+
)
267+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()
268+
sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called()
269+
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
270+
271+
def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
272+
273+
# case 3: resource does not exist
274+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline")
275+
276+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
277+
"PipelineArn": "pipeline-arn"
278+
}
279+
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
280+
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
281+
}
282+
283+
tags = [
284+
{"Key": "foo", "Value": "abc"},
285+
{"Key": "bar", "Value": "xyz"},
286+
]
287+
288+
pipeline = Pipeline(
289+
name="MyPipeline",
290+
parameters=[],
291+
steps=[],
292+
sagemaker_session=sagemaker_session_mock,
293+
)
294+
295+
try:
296+
pipeline.upsert(role_arn=role_arn, tags=tags)
297+
except ClientError:
298+
assert False, f"Unexpected ClientError raised"
299+
300+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
301+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
302+
Tags=tags
303+
)
304+
305+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()
306+
sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called()
307+
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
308+
214309

215310
def test_pipeline_delete(sagemaker_session_mock):
216311
pipeline = Pipeline(

0 commit comments

Comments
 (0)