Skip to content

Commit 8ef91ab

Browse files
author
Keshav Chandak
committed
bugix: Added check for the presence of model package group before creating one
1 parent bbbb76b commit 8ef91ab

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

src/sagemaker/session.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -4347,11 +4347,39 @@ def submit(request):
43474347
if model_package_group_name is not None and not model_package_group_name.startswith(
43484348
"arn:"
43494349
):
4350-
_create_resource(
4351-
lambda: self.sagemaker_client.create_model_package_group(
4352-
ModelPackageGroupName=request["ModelPackageGroupName"]
4350+
model_package_groups = []
4351+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4352+
NameContains=request["ModelPackageGroupName"],
4353+
)
4354+
model_package_groups = (
4355+
model_package_groups
4356+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4357+
)
4358+
next_token = model_package_groups_response.get("NextToken")
4359+
4360+
while next_token is not None and next_token != "":
4361+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4362+
NameContains=request["ModelPackageGroupName"], NextToken=next_token
4363+
)
4364+
model_package_groups = (
4365+
model_package_groups
4366+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4367+
)
4368+
next_token = model_package_groups_response.get("NextToken")
4369+
4370+
filtered_model_package_group = list(
4371+
filter(
4372+
lambda mpg: mpg.get("ModelPackageGroupName")
4373+
== request["ModelPackageGroupName"],
4374+
model_package_groups,
43534375
)
43544376
)
4377+
if not filtered_model_package_group:
4378+
_create_resource(
4379+
lambda: self.sagemaker_client.create_model_package_group(
4380+
ModelPackageGroupName=request["ModelPackageGroupName"]
4381+
)
4382+
)
43554383
if "SourceUri" in request and request["SourceUri"] is not None:
43564384
# Remove inference spec from request if the
43574385
# given source uri can lead to auto-population of it

tests/unit/test_session.py

+44
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,9 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session)
50065006
domain = "COMPUTER_VISION"
50075007
task = "IMAGE_CLASSIFICATION"
50085008
sample_payload_url = "s3://test-bucket/model"
5009+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5010+
"ModelPackageGroupSummaryList": []
5011+
}
50095012
sagemaker_session.create_model_package_from_containers(
50105013
containers=containers,
50115014
content_types=content_types,
@@ -5094,6 +5097,10 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec
50945097
skip_model_validation = "All"
50955098
source_uri = "dummy-source-uri"
50965099

5100+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5101+
"ModelPackageGroupSummaryList": []
5102+
}
5103+
50975104
created_versioned_mp_arn = (
50985105
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
50995106
)
@@ -5149,6 +5156,9 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp
51495156
approval_status = ("Approved",)
51505157
skip_model_validation = "All"
51515158
source_uri = "dummy-source-uri"
5159+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5160+
"ModelPackageGroupSummaryList": []
5161+
}
51525162

51535163
with pytest.raises(
51545164
ValueError,
@@ -5221,6 +5231,10 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake
52215231
return_value={"ModelPackageArn": created_versioned_mp_arn}
52225232
)
52235233

5234+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5235+
"ModelPackageGroupSummaryList": []
5236+
}
5237+
52245238
sagemaker_session.create_model_package_from_containers(
52255239
model_package_group_name=model_package_group_name,
52265240
containers=containers,
@@ -5443,6 +5457,9 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
54435457
approval_status = ("Approved",)
54445458
description = "description"
54455459
customer_metadata_properties = {"key1": "value1"}
5460+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5461+
"ModelPackageGroupSummaryList": []
5462+
}
54465463
sagemaker_session.create_model_package_from_containers(
54475464
containers=containers,
54485465
content_types=content_types,
@@ -5510,6 +5527,9 @@ def test_create_model_package_from_containers_with_one_instance_types(
55105527
approval_status = ("Approved",)
55115528
description = "description"
55125529
customer_metadata_properties = {"key1": "value1"}
5530+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
5531+
"ModelPackageGroupSummaryList": []
5532+
}
55135533
sagemaker_session.create_model_package_from_containers(
55145534
containers=containers,
55155535
content_types=content_types,
@@ -7183,3 +7203,27 @@ def test_delete_hub_content_reference(sagemaker_session):
71837203
}
71847204

71857205
sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request)
7206+
7207+
7208+
def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session):
7209+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
7210+
"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}]
7211+
}
7212+
sagemaker_session.create_model_package_from_containers(
7213+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7214+
)
7215+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7216+
sagemaker_session.create_model_package_from_containers(
7217+
source_uri="mock-source-uri",
7218+
model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg",
7219+
)
7220+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7221+
sagemaker_session.sagemaker_client.list_model_package_groups.return_value = {
7222+
"ModelPackageGroupSummaryList": []
7223+
}
7224+
sagemaker_session.create_model_package_from_containers(
7225+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7226+
)
7227+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
7228+
ModelPackageGroupName="mock-mpg"
7229+
)

0 commit comments

Comments
 (0)