Skip to content

Commit 71fe857

Browse files
author
Keshav Chandak
committed
Feat: Added support for returing most recently created approved model package in a group
1 parent 8a6ab21 commit 71fe857

File tree

3 files changed

+135
-1
lines changed

3 files changed

+135
-1
lines changed

src/sagemaker/session.py

+42
Original file line numberDiff line numberDiff line change
@@ -4463,6 +4463,48 @@ def wait_for_model_package(self, model_package_name, poll=5):
44634463
)
44644464
return desc
44654465

4466+
def get_most_recently_created_approved_model_package(self, model_package_group_name):
4467+
"""Returns the most recently created and Approved model package in a model package group
4468+
4469+
Args:
4470+
model_package_group_name (str): Name or Arn of the model package group
4471+
4472+
Returns:
4473+
dict: Returns a "sagemaker.model.ModelPackage" value.
4474+
"""
4475+
4476+
approved_model_packages = self.sagemaker_client.list_model_packages(
4477+
ModelPackageGroupName=model_package_group_name,
4478+
ModelApprovalStatus="Approved",
4479+
SortBy="CreationTime",
4480+
SortOrder="Descending",
4481+
MaxResults=1,
4482+
)
4483+
next_token = approved_model_packages.get("NextToken")
4484+
4485+
while (
4486+
len(approved_model_packages.get("ModelPackageSummaryList")) == 0
4487+
and next_token is not None
4488+
and next_token != ""
4489+
):
4490+
approved_model_packages = self.sagemaker_client.list_model_packages(
4491+
ModelPackageGroupName=model_package_group_name,
4492+
ModelApprovalStatus="Approved",
4493+
SortBy="CreationTime",
4494+
SortOrder="Descending",
4495+
MaxResults=1,
4496+
)
4497+
next_token = approved_model_packages.get("NextToken")
4498+
4499+
if len(approved_model_packages.get("ModelPackageSummaryList")) == 0:
4500+
return None
4501+
4502+
return sagemaker.model.ModelPackage(
4503+
model_package_arn=approved_model_packages.get("ModelPackageSummaryList")[0].get(
4504+
"ModelPackageArn"
4505+
)
4506+
)
4507+
44664508
def describe_model(self, name):
44674509
"""Calls the DescribeModel API for the given model name.
44684510

tests/integ/test_session.py

+61-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import boto3
1616
from botocore.config import Config
1717

18-
from sagemaker import Session
18+
from sagemaker import Session, ModelPackage
19+
from sagemaker.utils import unique_name_from_base
1920

2021
CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"
2122

@@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init(
4445

4546
s3 = boto3.resource("s3", region_name=boto_session.region_name)
4647
assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None
48+
49+
50+
def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session):
51+
model_package_group_name = unique_name_from_base("test-model-package-group")
52+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
53+
model_package_group_name=model_package_group_name
54+
)
55+
assert approved_model_package is None
56+
sagemaker_session.sagemaker_client.create_model_package_group(
57+
ModelPackageGroupName=model_package_group_name
58+
)
59+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
60+
model_package_group_name=model_package_group_name
61+
)
62+
assert approved_model_package is None
63+
source_uri = "dummy source uri"
64+
model_package = sagemaker_session.sagemaker_client.create_model_package(
65+
ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
66+
)
67+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
68+
model_package_group_name=model_package_group_name
69+
)
70+
assert approved_model_package is None
71+
ModelPackage(
72+
sagemaker_session=sagemaker_session,
73+
model_package_arn=model_package["ModelPackageArn"],
74+
).update_approval_status(approval_status="Approved")
75+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
76+
model_package_group_name=model_package_group_name
77+
)
78+
assert approved_model_package is not None
79+
assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
80+
model_package_2 = sagemaker_session.sagemaker_client.create_model_package(
81+
ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
82+
)
83+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
84+
model_package_group_name=model_package_group_name
85+
)
86+
assert approved_model_package is not None
87+
assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
88+
ModelPackage(
89+
sagemaker_session=sagemaker_session,
90+
model_package_arn=model_package_2["ModelPackageArn"],
91+
).update_approval_status(approval_status="Approved")
92+
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
93+
model_package_group_name=model_package_group_name
94+
)
95+
assert approved_model_package is not None
96+
assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn")
97+
98+
sagemaker_session.sagemaker_client.delete_model_package(
99+
ModelPackageName=model_package_2["ModelPackageArn"]
100+
)
101+
sagemaker_session.sagemaker_client.delete_model_package(
102+
ModelPackageName=model_package["ModelPackageArn"]
103+
)
104+
sagemaker_session.sagemaker_client.delete_model_package_group(
105+
ModelPackageGroupName=model_package_group_name
106+
)

tests/unit/test_session.py

+32
Original file line numberDiff line numberDiff line change
@@ -7253,3 +7253,35 @@ def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagem
72537253
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
72547254
ModelPackageGroupName="mock-mpg"
72557255
)
7256+
7257+
7258+
def test_get_most_recently_created_approved_model_package(sagemaker_session):
7259+
sagemaker_session.sagemaker_client.list_model_packages.side_effect = [
7260+
(
7261+
{
7262+
"ModelPackageSummaryList": [],
7263+
"NextToken": "NextToken",
7264+
}
7265+
),
7266+
(
7267+
{
7268+
"ModelPackageSummaryList": [
7269+
{
7270+
"CreationTime": 1697440162,
7271+
"ModelApprovalStatus": "Approved",
7272+
"ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/model-version/3",
7273+
"ModelPackageGroupName": "model-version",
7274+
"ModelPackageVersion": 3,
7275+
},
7276+
],
7277+
}
7278+
),
7279+
]
7280+
model_package = sagemaker_session.get_most_recently_created_approved_model_package(
7281+
model_package_group_name="mpg"
7282+
)
7283+
assert model_package is not None
7284+
assert (
7285+
model_package.model_package_arn
7286+
== "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/model-version/3"
7287+
)

0 commit comments

Comments
 (0)