Skip to content

documentation: Add Model Registry Model Collection #3788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/api/inference/model_collection.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Model Collection
----------------

.. automodule:: sagemaker.collection
:members:
:undoc-members:
:show-inheritance:

97 changes: 50 additions & 47 deletions src/sagemaker/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""This module contains code related to Amazon SageMaker Collection.
"""This module contains code related to Amazon SageMaker Collections in the Model Registry.

These Classes helps in providing features to maintain and create collections
Use these methods to help you create and maintain your Collections.
"""

from __future__ import absolute_import
Expand All @@ -27,27 +27,29 @@


class Collection(object):
"""Sets up Amazon SageMaker Collection."""
"""Sets up an Amazon SageMaker Collection."""

def __init__(self, sagemaker_session):
"""Initializes a Collection instance.

The collection provides a logical grouping for model groups
A Collection is a logical grouping of Model Groups.

Args:
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using
sagemaker_session (sagemaker.session.Session): A Session object which
manages interactions between Amazon SageMaker APIs and other
AWS services needed. If unspecified, a session is created using
the default AWS configuration chain.
"""

self.sagemaker_session = sagemaker_session or Session()

def _check_access_error(self, err: ClientError):
"""To check if the error is related to the access error and to provide the relavant message
"""Checks if the error is related to the access error and provide the relevant message.

Args:
err: The client error that needs to be checked
err: The client error to check.
"""

error_code = err.response["Error"]["Code"]
if error_code == "AccessDeniedException":
raise Exception(
Expand All @@ -57,12 +59,12 @@ def _check_access_error(self, err: ClientError):
)

def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value):
"""To add a model package group to a collection
"""Adds a Model Group to a Collection.

Args:
model_package_group (str): The name of the model package group
tag_rule_key (str): The tag key of the corresponing collection to be added into
tag_rule_value (str): The tag value of the corresponing collection to be added into
model_package_group (str): The name of the Model Group.
tag_rule_key (str): The tag key of the destination collection.
tag_rule_value (str): The tag value of the destination collection.
"""
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
ModelPackageGroupName=model_package_group
Expand All @@ -78,11 +80,11 @@ def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value):
)

def _remove_model_group(self, model_package_group, tag_rule_key):
"""To remove a model package group from a collection
"""Removes a Model Group from a Collection.

Args:
model_package_group (str): The name of the model package group
tag_rule_key (str): The tag key of the corresponing collection to be removed from
model_package_group (str): The name of the Model Group
tag_rule_key (str): The tag key of the Collection from which to remove the Model Group.
"""
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
ModelPackageGroupName=model_package_group
Expand All @@ -92,12 +94,12 @@ def _remove_model_group(self, model_package_group, tag_rule_key):
)

def create(self, collection_name: str, parent_collection_name: str = None):
"""Creates a collection
"""Creates a Collection.

Args:
collection_name (str): The name of the collection to be created
parent_collection_name (str): The name of the parent collection.
To be None if the collection is to be created on the root level
collection_name (str): The name of the Collection to create.
parent_collection_name (str): The name of the parent Collection.
Is ``None`` if the Collection is created at the root level.
"""

tag_rule_key = f"sagemaker:collection-path:{int(time.time() * 1000)}"
Expand Down Expand Up @@ -151,11 +153,11 @@ def create(self, collection_name: str, parent_collection_name: str = None):
raise

def delete(self, collections: List[str]):
"""Deletes a list of collection.
"""Deletes a list of Collections.

Args:
collections (List[str]): List of collections to be deleted
Only deletes a collection if it is empty
collections (List[str]): A list of Collections to delete.
Only deletes a Collection if it is empty.
"""

if len(collections) > 10:
Expand Down Expand Up @@ -201,7 +203,7 @@ def delete(self, collections: List[str]):
}

def _get_collection_tag_rule(self, collection_name: str):
"""Returns the tag rule key and value for a collection"""
"""Returns the tag rule key and value for a Collection."""

if collection_name is not None:
try:
Expand Down Expand Up @@ -230,11 +232,11 @@ def _get_collection_tag_rule(self, collection_name: str):
raise ValueError("Collection name is required")

def add_model_groups(self, collection_name: str, model_groups: List[str]):
"""To add list of model package groups to a collection
"""Adds a list of Model Groups to a Collection.

Args:
collection_name (str): The name of the collection
model_groups List[str]: Model pckage group names list to be added into the collection
collection_name (str): The name of the Collection.
model_groups (List[str]): The names of the Model Groups to add to the Collection.
"""
if len(model_groups) > 10:
raise Exception("Model groups can have a maximum length of 10")
Expand Down Expand Up @@ -268,11 +270,11 @@ def add_model_groups(self, collection_name: str, model_groups: List[str]):
}

def remove_model_groups(self, collection_name: str, model_groups: List[str]):
"""To remove list of model package groups from a collection
"""Removes a list of Model Groups from a Collection.

Args:
collection_name (str): The name of the collection
model_groups List[str]: Model package group names list to be removed
collection_name (str): The name of the Collection.
model_groups (List[str]): The names of the Model Groups to remove.
"""

if len(model_groups) > 10:
Expand Down Expand Up @@ -309,12 +311,12 @@ def remove_model_groups(self, collection_name: str, model_groups: List[str]):
def move_model_group(
self, source_collection_name: str, model_group: str, destination_collection_name: str
):
"""To move a model package group from one collection to another
"""Moves a Model Group from one Collection to another.

Args:
source_collection_name (str): Collection name of the source
model_group (str): Model package group names which is to be moved
destination_collection_name (str): Collection name of the destination
source_collection_name (str): The name of the source Collection.
model_group (str): The name of the Model Group to move.
destination_collection_name (str): The name of the destination Collection.
"""
remove_details = self.remove_model_groups(
collection_name=source_collection_name, model_groups=[model_group]
Expand All @@ -327,7 +329,7 @@ def move_model_group(
)

if len(added_details["failure"]) == 1:
# adding the model group back to the source collection in case of an add failure
# adding the Model Group back to the source collection in case of an add failure
self.add_model_groups(
collection_name=source_collection_name, model_groups=[model_group]
)
Expand All @@ -338,10 +340,10 @@ def move_model_group(
}

def _convert_tag_collection_response(self, tag_collections: List[str]):
"""Converts collection response from tag api to collection list response
"""Converts a Collection response from the tag api to a Collection list response.

Args:
tag_collections List[dict]: Collections list response from tag api
tag_collections List[dict]: The Collection list response from the tag api.
"""
collection_details = []
for collection in tag_collections:
Expand All @@ -359,11 +361,12 @@ def _convert_tag_collection_response(self, tag_collections: List[str]):
def _convert_group_resource_response(
self, group_resource_details: List[dict], is_model_group: bool = False
):
"""Converts collection response from resource group api to collection list response
"""Converts a Collection response from the resource group api to a Collection list response.

Args:
group_resource_details (List[dict]): Collections list response from resource group api
is_model_group (bool): If the reponse is of collection or model group type
group_resource_details (List[dict]): The Collection list response from the
resource group api.
is_model_group (bool): Indicates if the response is of Collection or Model Group type.
"""
collection_details = []
if group_resource_details["Resources"]:
Expand All @@ -382,12 +385,11 @@ def _convert_group_resource_response(
return collection_details

def _get_full_list_resource(self, collection_name, collection_filter):
"""Iterating to the full resource group list and returns appended paginated response
"""Iterates the full resource group list and returns the appended paginated response.

Args:
collection_name (str): Name of the collection to get the details
collection_filter (dict): Filter details to be passed to get the resource list

collection_name (str): The name of the Collection from which to get details.
collection_filter (dict): Filter details to pass to get the resource list.
"""
list_group_response = self.sagemaker_session.list_group_resources(
group=collection_name, filters=collection_filter
Expand All @@ -412,12 +414,13 @@ def _get_full_list_resource(self, collection_name, collection_filter):
return list_group_response

def list_collection(self, collection_name: str = None):
"""To all list the collections and content of the collections
"""Lists the contents of the specified Collection.

In case there is no collection_name, it lists all the collections on the root level
If there is no Collection with the name ``collection_name``, lists all the
Collections at the root level.

Args:
collection_name (str): The name of the collection to list the contents of
collection_name (str): The name of the Collection whose contents are listed.
"""
collection_content = []
if collection_name is None:
Expand Down