Skip to content

Feat: Added support for returing most recently created approved mp #5092

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4463,6 +4463,49 @@ def wait_for_model_package(self, model_package_name, poll=5):
)
return desc

def get_most_recently_created_approved_model_package(self, model_package_group_name):
"""Returns the most recently created and Approved model package in a model package group

Args:
model_package_group_name (str): Name or Arn of the model package group

Returns:
dict: Returns a "sagemaker.model.ModelPackage" value.
"""

approved_model_packages = self.sagemaker_client.list_model_packages(
ModelPackageGroupName=model_package_group_name,
ModelApprovalStatus="Approved",
SortBy="CreationTime",
SortOrder="Descending",
MaxResults=1,
)
next_token = approved_model_packages.get("NextToken")

while (
len(approved_model_packages.get("ModelPackageSummaryList")) == 0
and next_token is not None
and next_token != ""
):
approved_model_packages = self.sagemaker_client.list_model_packages(
ModelPackageGroupName=model_package_group_name,
ModelApprovalStatus="Approved",
SortBy="CreationTime",
SortOrder="Descending",
MaxResults=1,
NextToken=next_token,
)
next_token = approved_model_packages.get("NextToken")

if len(approved_model_packages.get("ModelPackageSummaryList")) == 0:
return None

return sagemaker.model.ModelPackage(
model_package_arn=approved_model_packages.get("ModelPackageSummaryList")[0].get(
"ModelPackageArn"
)
)

def describe_model(self, name):
"""Calls the DescribeModel API for the given model name.

Expand Down
62 changes: 61 additions & 1 deletion tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import boto3
from botocore.config import Config

from sagemaker import Session
from sagemaker import Session, ModelPackage
from sagemaker.utils import unique_name_from_base

CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"

Expand Down Expand Up @@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init(

s3 = boto3.resource("s3", region_name=boto_session.region_name)
assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None


def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session):
model_package_group_name = unique_name_from_base("test-model-package-group")
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is None
sagemaker_session.sagemaker_client.create_model_package_group(
ModelPackageGroupName=model_package_group_name
)
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is None
source_uri = "dummy source uri"
model_package = sagemaker_session.sagemaker_client.create_model_package(
ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
)
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is None
ModelPackage(
sagemaker_session=sagemaker_session,
model_package_arn=model_package["ModelPackageArn"],
).update_approval_status(approval_status="Approved")
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is not None
assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
model_package_2 = sagemaker_session.sagemaker_client.create_model_package(
ModelPackageGroupName=model_package_group_name, SourceUri=source_uri
)
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is not None
assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn")
ModelPackage(
sagemaker_session=sagemaker_session,
model_package_arn=model_package_2["ModelPackageArn"],
).update_approval_status(approval_status="Approved")
approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name=model_package_group_name
)
assert approved_model_package is not None
assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn")

sagemaker_session.sagemaker_client.delete_model_package(
ModelPackageName=model_package_2["ModelPackageArn"]
)
sagemaker_session.sagemaker_client.delete_model_package(
ModelPackageName=model_package["ModelPackageArn"]
)
sagemaker_session.sagemaker_client.delete_model_package_group(
ModelPackageGroupName=model_package_group_name
)
32 changes: 32 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7253,3 +7253,35 @@ def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagem
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
ModelPackageGroupName="mock-mpg"
)


def test_get_most_recently_created_approved_model_package(sagemaker_session):
sagemaker_session.sagemaker_client.list_model_packages.side_effect = [
(
{
"ModelPackageSummaryList": [],
"NextToken": "NextToken",
}
),
(
{
"ModelPackageSummaryList": [
{
"CreationTime": 1697440162,
"ModelApprovalStatus": "Approved",
"ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3",
"ModelPackageGroupName": "model-version",
"ModelPackageVersion": 3,
},
],
}
),
]
model_package = sagemaker_session.get_most_recently_created_approved_model_package(
model_package_group_name="mpg"
)
assert model_package is not None
assert (
model_package.model_package_arn
== "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3"
)