From 1f5534972894cae8c5f10a5bbfd8de61a5baca7b Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Tue, 18 Mar 2025 16:51:28 +0530 Subject: [PATCH] Feat: Added support for returing most recently created approved model package in a group --- src/sagemaker/session.py | 43 +++++++++++++++++++++++++ tests/integ/test_session.py | 62 ++++++++++++++++++++++++++++++++++++- tests/unit/test_session.py | 32 +++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 38fa7f8c26..797d559348 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4463,6 +4463,49 @@ def wait_for_model_package(self, model_package_name, poll=5): ) return desc + def get_most_recently_created_approved_model_package(self, model_package_group_name): + """Returns the most recently created and Approved model package in a model package group + + Args: + model_package_group_name (str): Name or Arn of the model package group + + Returns: + dict: Returns a "sagemaker.model.ModelPackage" value. + """ + + approved_model_packages = self.sagemaker_client.list_model_packages( + ModelPackageGroupName=model_package_group_name, + ModelApprovalStatus="Approved", + SortBy="CreationTime", + SortOrder="Descending", + MaxResults=1, + ) + next_token = approved_model_packages.get("NextToken") + + while ( + len(approved_model_packages.get("ModelPackageSummaryList")) == 0 + and next_token is not None + and next_token != "" + ): + approved_model_packages = self.sagemaker_client.list_model_packages( + ModelPackageGroupName=model_package_group_name, + ModelApprovalStatus="Approved", + SortBy="CreationTime", + SortOrder="Descending", + MaxResults=1, + NextToken=next_token, + ) + next_token = approved_model_packages.get("NextToken") + + if len(approved_model_packages.get("ModelPackageSummaryList")) == 0: + return None + + return sagemaker.model.ModelPackage( + model_package_arn=approved_model_packages.get("ModelPackageSummaryList")[0].get( + "ModelPackageArn" + ) + ) + def describe_model(self, name): """Calls the DescribeModel API for the given model name. diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 0015efe3fd..0b2900bef7 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -15,7 +15,8 @@ import boto3 from botocore.config import Config -from sagemaker import Session +from sagemaker import Session, ModelPackage +from sagemaker.utils import unique_name_from_base CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist" @@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init( s3 = boto3.resource("s3", region_name=boto_session.region_name) assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None + + +def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session): + model_package_group_name = unique_name_from_base("test-model-package-group") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + sagemaker_session.sagemaker_client.create_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + source_uri = "dummy source uri" + model_package = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is None + ModelPackage( + sagemaker_session=sagemaker_session, + model_package_arn=model_package["ModelPackageArn"], + ).update_approval_status(approval_status="Approved") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") + model_package_2 = sagemaker_session.sagemaker_client.create_model_package( + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri + ) + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") + ModelPackage( + sagemaker_session=sagemaker_session, + model_package_arn=model_package_2["ModelPackageArn"], + ).update_approval_status(approval_status="Approved") + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name=model_package_group_name + ) + assert approved_model_package is not None + assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn") + + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package_2["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index f873e9b14c..e3d763e612 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7253,3 +7253,35 @@ def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagem sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with( ModelPackageGroupName="mock-mpg" ) + + +def test_get_most_recently_created_approved_model_package(sagemaker_session): + sagemaker_session.sagemaker_client.list_model_packages.side_effect = [ + ( + { + "ModelPackageSummaryList": [], + "NextToken": "NextToken", + } + ), + ( + { + "ModelPackageSummaryList": [ + { + "CreationTime": 1697440162, + "ModelApprovalStatus": "Approved", + "ModelPackageArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3", + "ModelPackageGroupName": "model-version", + "ModelPackageVersion": 3, + }, + ], + } + ), + ] + model_package = sagemaker_session.get_most_recently_created_approved_model_package( + model_package_group_name="mpg" + ) + assert model_package is not None + assert ( + model_package.model_package_arn + == "arn:aws:sagemaker:us-west-2:123456789012:model-package/model-version/3" + )