|
15 | 15 | import boto3
|
16 | 16 | from botocore.config import Config
|
17 | 17 |
|
18 |
| -from sagemaker import Session |
| 18 | +from sagemaker import Session, ModelPackage |
| 19 | +from sagemaker.utils import unique_name_from_base |
19 | 20 |
|
20 | 21 | CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"
|
21 | 22 |
|
@@ -44,3 +45,62 @@ def test_sagemaker_session_does_not_create_bucket_on_init(
|
44 | 45 |
|
45 | 46 | s3 = boto3.resource("s3", region_name=boto_session.region_name)
|
46 | 47 | assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None
|
| 48 | + |
| 49 | + |
| 50 | +def test_sagemaker_session_to_return_most_recent_approved_model_package(sagemaker_session): |
| 51 | + model_package_group_name = unique_name_from_base("test-model-package-group") |
| 52 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 53 | + model_package_group_name=model_package_group_name |
| 54 | + ) |
| 55 | + assert approved_model_package is None |
| 56 | + sagemaker_session.sagemaker_client.create_model_package_group( |
| 57 | + ModelPackageGroupName=model_package_group_name |
| 58 | + ) |
| 59 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 60 | + model_package_group_name=model_package_group_name |
| 61 | + ) |
| 62 | + assert approved_model_package is None |
| 63 | + source_uri = "dummy source uri" |
| 64 | + model_package = sagemaker_session.sagemaker_client.create_model_package( |
| 65 | + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri |
| 66 | + ) |
| 67 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 68 | + model_package_group_name=model_package_group_name |
| 69 | + ) |
| 70 | + assert approved_model_package is None |
| 71 | + ModelPackage( |
| 72 | + sagemaker_session=sagemaker_session, |
| 73 | + model_package_arn=model_package["ModelPackageArn"], |
| 74 | + ).update_approval_status(approval_status="Approved") |
| 75 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 76 | + model_package_group_name=model_package_group_name |
| 77 | + ) |
| 78 | + assert approved_model_package is not None |
| 79 | + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") |
| 80 | + model_package_2 = sagemaker_session.sagemaker_client.create_model_package( |
| 81 | + ModelPackageGroupName=model_package_group_name, SourceUri=source_uri |
| 82 | + ) |
| 83 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 84 | + model_package_group_name=model_package_group_name |
| 85 | + ) |
| 86 | + assert approved_model_package is not None |
| 87 | + assert approved_model_package.model_package_arn == model_package.get("ModelPackageArn") |
| 88 | + ModelPackage( |
| 89 | + sagemaker_session=sagemaker_session, |
| 90 | + model_package_arn=model_package_2["ModelPackageArn"], |
| 91 | + ).update_approval_status(approval_status="Approved") |
| 92 | + approved_model_package = sagemaker_session.get_most_recently_created_approved_model_package( |
| 93 | + model_package_group_name=model_package_group_name |
| 94 | + ) |
| 95 | + assert approved_model_package is not None |
| 96 | + assert approved_model_package.model_package_arn == model_package_2.get("ModelPackageArn") |
| 97 | + |
| 98 | + sagemaker_session.sagemaker_client.delete_model_package( |
| 99 | + ModelPackageName=model_package_2["ModelPackageArn"] |
| 100 | + ) |
| 101 | + sagemaker_session.sagemaker_client.delete_model_package( |
| 102 | + ModelPackageName=model_package["ModelPackageArn"] |
| 103 | + ) |
| 104 | + sagemaker_session.sagemaker_client.delete_model_package_group( |
| 105 | + ModelPackageGroupName=model_package_group_name |
| 106 | + ) |
0 commit comments