diff --git a/doc/api/inference/model_collection.rst b/doc/api/inference/model_collection.rst new file mode 100644 index 0000000000..7281350ae6 --- /dev/null +++ b/doc/api/inference/model_collection.rst @@ -0,0 +1,8 @@ +Model Collection +---------------- + +.. automodule:: sagemaker.collection + :members: + :undoc-members: + :show-inheritance: + diff --git a/src/sagemaker/collection.py b/src/sagemaker/collection.py index 7703b14b4d..7633085506 100644 --- a/src/sagemaker/collection.py +++ b/src/sagemaker/collection.py @@ -11,9 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module contains code related to Amazon SageMaker Collection. +"""This module contains code related to Amazon SageMaker Collections in the Model Registry. -These Classes helps in providing features to maintain and create collections +Use these methods to help you create and maintain your Collections. """ from __future__ import absolute_import @@ -27,27 +27,29 @@ class Collection(object): - """Sets up Amazon SageMaker Collection.""" + """Sets up an Amazon SageMaker Collection.""" def __init__(self, sagemaker_session): """Initializes a Collection instance. - The collection provides a logical grouping for model groups + A Collection is a logical grouping of Model Groups. Args: - sagemaker_session (sagemaker.session.Session): Session object which - manages interactions with Amazon SageMaker APIs and any other - AWS services needed. If not specified, one is created using + sagemaker_session (sagemaker.session.Session): A Session object which + manages interactions between Amazon SageMaker APIs and other + AWS services needed. If unspecified, a session is created using the default AWS configuration chain. """ + self.sagemaker_session = sagemaker_session or Session() def _check_access_error(self, err: ClientError): - """To check if the error is related to the access error and to provide the relavant message + """Checks if the error is related to the access error and provide the relevant message. Args: - err: The client error that needs to be checked + err: The client error to check. """ + error_code = err.response["Error"]["Code"] if error_code == "AccessDeniedException": raise Exception( @@ -57,12 +59,12 @@ def _check_access_error(self, err: ClientError): ) def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value): - """To add a model package group to a collection + """Adds a Model Group to a Collection. Args: - model_package_group (str): The name of the model package group - tag_rule_key (str): The tag key of the corresponing collection to be added into - tag_rule_value (str): The tag value of the corresponing collection to be added into + model_package_group (str): The name of the Model Group. + tag_rule_key (str): The tag key of the destination collection. + tag_rule_value (str): The tag value of the destination collection. """ model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group( ModelPackageGroupName=model_package_group @@ -78,11 +80,11 @@ def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value): ) def _remove_model_group(self, model_package_group, tag_rule_key): - """To remove a model package group from a collection + """Removes a Model Group from a Collection. Args: - model_package_group (str): The name of the model package group - tag_rule_key (str): The tag key of the corresponing collection to be removed from + model_package_group (str): The name of the Model Group + tag_rule_key (str): The tag key of the Collection from which to remove the Model Group. """ model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group( ModelPackageGroupName=model_package_group @@ -92,12 +94,12 @@ def _remove_model_group(self, model_package_group, tag_rule_key): ) def create(self, collection_name: str, parent_collection_name: str = None): - """Creates a collection + """Creates a Collection. Args: - collection_name (str): The name of the collection to be created - parent_collection_name (str): The name of the parent collection. - To be None if the collection is to be created on the root level + collection_name (str): The name of the Collection to create. + parent_collection_name (str): The name of the parent Collection. + Is ``None`` if the Collection is created at the root level. """ tag_rule_key = f"sagemaker:collection-path:{int(time.time() * 1000)}" @@ -151,11 +153,11 @@ def create(self, collection_name: str, parent_collection_name: str = None): raise def delete(self, collections: List[str]): - """Deletes a list of collection. + """Deletes a list of Collections. Args: - collections (List[str]): List of collections to be deleted - Only deletes a collection if it is empty + collections (List[str]): A list of Collections to delete. + Only deletes a Collection if it is empty. """ if len(collections) > 10: @@ -201,7 +203,7 @@ def delete(self, collections: List[str]): } def _get_collection_tag_rule(self, collection_name: str): - """Returns the tag rule key and value for a collection""" + """Returns the tag rule key and value for a Collection.""" if collection_name is not None: try: @@ -230,11 +232,11 @@ def _get_collection_tag_rule(self, collection_name: str): raise ValueError("Collection name is required") def add_model_groups(self, collection_name: str, model_groups: List[str]): - """To add list of model package groups to a collection + """Adds a list of Model Groups to a Collection. Args: - collection_name (str): The name of the collection - model_groups List[str]: Model pckage group names list to be added into the collection + collection_name (str): The name of the Collection. + model_groups (List[str]): The names of the Model Groups to add to the Collection. """ if len(model_groups) > 10: raise Exception("Model groups can have a maximum length of 10") @@ -268,11 +270,11 @@ def add_model_groups(self, collection_name: str, model_groups: List[str]): } def remove_model_groups(self, collection_name: str, model_groups: List[str]): - """To remove list of model package groups from a collection + """Removes a list of Model Groups from a Collection. Args: - collection_name (str): The name of the collection - model_groups List[str]: Model package group names list to be removed + collection_name (str): The name of the Collection. + model_groups (List[str]): The names of the Model Groups to remove. """ if len(model_groups) > 10: @@ -309,12 +311,12 @@ def remove_model_groups(self, collection_name: str, model_groups: List[str]): def move_model_group( self, source_collection_name: str, model_group: str, destination_collection_name: str ): - """To move a model package group from one collection to another + """Moves a Model Group from one Collection to another. Args: - source_collection_name (str): Collection name of the source - model_group (str): Model package group names which is to be moved - destination_collection_name (str): Collection name of the destination + source_collection_name (str): The name of the source Collection. + model_group (str): The name of the Model Group to move. + destination_collection_name (str): The name of the destination Collection. """ remove_details = self.remove_model_groups( collection_name=source_collection_name, model_groups=[model_group] @@ -327,7 +329,7 @@ def move_model_group( ) if len(added_details["failure"]) == 1: - # adding the model group back to the source collection in case of an add failure + # adding the Model Group back to the source collection in case of an add failure self.add_model_groups( collection_name=source_collection_name, model_groups=[model_group] ) @@ -338,10 +340,10 @@ def move_model_group( } def _convert_tag_collection_response(self, tag_collections: List[str]): - """Converts collection response from tag api to collection list response + """Converts a Collection response from the tag api to a Collection list response. Args: - tag_collections List[dict]: Collections list response from tag api + tag_collections List[dict]: The Collection list response from the tag api. """ collection_details = [] for collection in tag_collections: @@ -359,11 +361,12 @@ def _convert_tag_collection_response(self, tag_collections: List[str]): def _convert_group_resource_response( self, group_resource_details: List[dict], is_model_group: bool = False ): - """Converts collection response from resource group api to collection list response + """Converts a Collection response from the resource group api to a Collection list response. Args: - group_resource_details (List[dict]): Collections list response from resource group api - is_model_group (bool): If the reponse is of collection or model group type + group_resource_details (List[dict]): The Collection list response from the + resource group api. + is_model_group (bool): Indicates if the response is of Collection or Model Group type. """ collection_details = [] if group_resource_details["Resources"]: @@ -382,12 +385,11 @@ def _convert_group_resource_response( return collection_details def _get_full_list_resource(self, collection_name, collection_filter): - """Iterating to the full resource group list and returns appended paginated response + """Iterates the full resource group list and returns the appended paginated response. Args: - collection_name (str): Name of the collection to get the details - collection_filter (dict): Filter details to be passed to get the resource list - + collection_name (str): The name of the Collection from which to get details. + collection_filter (dict): Filter details to pass to get the resource list. """ list_group_response = self.sagemaker_session.list_group_resources( group=collection_name, filters=collection_filter @@ -412,12 +414,13 @@ def _get_full_list_resource(self, collection_name, collection_filter): return list_group_response def list_collection(self, collection_name: str = None): - """To all list the collections and content of the collections + """Lists the contents of the specified Collection. - In case there is no collection_name, it lists all the collections on the root level + If there is no Collection with the name ``collection_name``, lists all the + Collections at the root level. Args: - collection_name (str): The name of the collection to list the contents of + collection_name (str): The name of the Collection whose contents are listed. """ collection_content = [] if collection_name is None: