Skip to content

Commit 1f37c6d

Browse files
authored
fix: Fix cross account register model (#3726)
1 parent 8d282c1 commit 1f37c6d

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

src/sagemaker/session.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3222,7 +3222,9 @@ def create_model_package_from_containers(
32223222
)
32233223

32243224
def submit(request):
3225-
if model_package_group_name is not None:
3225+
if model_package_group_name is not None and not model_package_group_name.startswith(
3226+
"arn:"
3227+
):
32263228
_create_resource(
32273229
lambda: self.sagemaker_client.create_model_package_group(
32283230
ModelPackageGroupName=request["ModelPackageGroupName"]

tests/unit/test_session.py

+29
Original file line numberDiff line numberDiff line change
@@ -2650,6 +2650,35 @@ def test_create_model_package_from_containers(sagemaker_session):
26502650
sagemaker_session.sagemaker_client.create_model_package.assert_called_once()
26512651

26522652

2653+
def test_create_model_package_from_containers_cross_account_mpg_name(sagemaker_session):
2654+
mpg_name = "arn:aws:sagemaker:us-east-1:215995503607:model-package-group/stage-dev"
2655+
content_types = ["text/csv"]
2656+
response_types = ["text/csv"]
2657+
sagemaker_session.create_model_package_from_containers(
2658+
model_package_group_name=mpg_name,
2659+
content_types=content_types,
2660+
response_types=response_types,
2661+
)
2662+
sagemaker_session.sagemaker_client.create_model_package.assert_called_once()
2663+
2664+
2665+
def test_create_mpg_from_containers_cross_account_mpg_name(sagemaker_session):
2666+
mpg_name = "arn:aws:sagemaker:us-east-1:215995503607:model-package-group/stage-dev"
2667+
content_types = ["text/csv"]
2668+
response_types = ["text/csv"]
2669+
with pytest.raises(AssertionError) as error:
2670+
sagemaker_session.create_model_package_from_containers(
2671+
model_package_group_name=mpg_name,
2672+
content_types=content_types,
2673+
response_types=response_types,
2674+
)
2675+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_once()
2676+
assert (
2677+
"Expected 'create_model_package_group' to have been called once. "
2678+
"Called 0 times." == str(error)
2679+
)
2680+
2681+
26532682
def test_create_model_package_from_containers_name_conflict(sagemaker_session):
26542683
model_package_name = "sagemaker-model-package"
26552684
model_package_group_name = "sagemaker-model-package-group"

0 commit comments

Comments
 (0)