diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ff5a82a902..a49cc83ab0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6207,7 +6207,7 @@ def _intercept_create_request( self, request: typing.Dict, create, - func_name: str = None + func_name: str = None, # pylint: disable=unused-argument ): """This function intercepts the create job request. @@ -6470,6 +6470,269 @@ def wait_for_inference_recommendations_job( _check_job_status(job_name, desc, "Status") return desc + def create_hub( + self, + hub_name: str, + hub_description: str, + hub_display_name: str = None, + hub_search_keywords: List[str] = None, + s3_storage_config: Dict[str, Any] = None, + tags: List[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Creates a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to create. + hub_description (str): A description of the Hub. + hub_display_name (str): The display name of the Hub. + hub_search_keywords (list): The searchable keywords for the Hub. + s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub. + tags (list): Any tags to associate with the Hub. + + Returns: + (dict): Return value from the ``CreateHub`` API. + """ + request = {"HubName": hub_name, "HubDescription": hub_description} + if hub_display_name: + request["HubDisplayName"] = hub_display_name + if hub_search_keywords: + request["HubSearchKeywords"] = hub_search_keywords + if s3_storage_config: + request["S3StorageConfig"] = s3_storage_config + if tags: + request["Tags"] = tags + + return self.sagemaker_client.create_hub(**request) + + def describe_hub(self, hub_name: str) -> Dict[str, Any]: + """Describes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to describe. + + Returns: + (dict): Return value for ``DescribeHub`` API + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.describe_hub(**request) + + def list_hubs( + self, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists all existing SageMaker Hubs + + Args: + creation_time_after (str): Only list HubContent that was created after + the time specified. + creation_time_before (str): Only list HubContent that was created + before the time specified. + max_results (int): The maximum amount of HubContent to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list HubContent if the name contains the specified string. + next_token (str): If the response to a previous ``ListHubContents`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + hub content, use the token in the next request. + sort_by (str): Sort HubContent versions by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order. + Returns: + (dict): Return value for ``ListHubs`` API + """ + request = {} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hubs(**request) + + def list_hub_contents( + self, + hub_name: str, + hub_content_type: str, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists the HubContents in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to list the contents of. + hub_content_type (str): The type of the HubContent to list. + creation_time_after (str): Only list HubContent that was created after the + time specified. + creation_time_before (str): Only list HubContent that was created before the + time specified. + max_results (int): The maximum amount of HubContent to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list HubContent if the name contains the specified string. + next_token (str): If the response to a previous ``ListHubContents`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + hub content, use the token in the next request. + sort_by (str): Sort HubContent versions by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order. + Returns: + (dict): Return value for ``ListHubContents`` API + """ + request = {"HubName": hub_name, "HubContentType": hub_content_type} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hub_contents(**request) + + def delete_hub(self, hub_name: str) -> None: + """Deletes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to delete. + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.delete_hub(**request) + + def import_hub_content( + self, + document_schema_version: str, + hub_content_name: str, + hub_content_type: str, + hub_name: str, + hub_content_document: str, + hub_content_display_name: str = None, + hub_content_description: str = None, + hub_content_version: str = None, + hub_content_markdown: str = None, + hub_content_search_keywords: List[str] = None, + tags: List[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Imports a new HubContent into a SageMaker Hub + + Args: + document_schema_version (str): The version of the HubContent schema to import. + hub_content_name (str): The name of the HubContent to import. + hub_content_version (str): The version of the HubContent to import. + hub_content_type (str): The type of HubContent to import. + hub_name (str): The name of the Hub to import content to. + hub_content_document (str): The hub content document that describes information + about the hub content such as type, associated containers, scripts, and more. + hub_content_display_name (str): The display name of the HubContent to import. + hub_content_description (str): The description of the HubContent to import. + hub_content_markdown (str): A string that provides a description of the HubContent. + This string can include links, tables, and standard markdown formatting. + hub_content_search_keywords (list): The searchable keywords of the HubContent. + tags (list): Any tags associated with the HubContent. + Returns: + (dict): Return value for ``ImportHubContent`` API + """ + request = { + "DocumentSchemaVersion": document_schema_version, + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + "HubName": hub_name, + "HubContentDocument": hub_content_document, + } + if hub_content_display_name: + request["HubContentDisplayName"] = hub_content_display_name + if hub_content_description: + request["HubContentDescription"] = hub_content_description + if hub_content_version: + request["HubContentVersion"] = hub_content_version + if hub_content_markdown: + request["HubContentMarkdown"] = hub_content_markdown + if hub_content_search_keywords: + request["HubContentSearchKeywords"] = hub_content_search_keywords + if tags: + request["Tags"] = tags + + return self.sagemaker_client.import_hub_content(**request) + + def describe_hub_content( + self, + hub_content_name: str, + hub_content_type: str, + hub_name: str, + hub_content_version: str = None, + ) -> Dict[str, Any]: + """Describes a HubContent in a SageMaker Hub + + Args: + hub_content_name (str): The name of the HubContent to describe. + hub_content_type (str): The type of HubContent in the Hub. + hub_name (str): The name of the Hub that contains the HubContent to describe. + hub_content_version (str): The version of the HubContent to describe + + Returns: + (dict): Return value for ``DescribeHubContent`` API + """ + request = { + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + "HubName": hub_name, + } + if hub_content_version: + request["HubContentVersion"] = hub_content_version + + return self.sagemaker_client.describe_hub_content(**request) + + def delete_hub_content( + self, hub_content_name: str, hub_content_version: str, hub_content_type: str, hub_name: str + ) -> None: + """Deletes a given HubContent in a SageMaker Hub + + Args: + hub_content_name (str): The name of the content thatyou want to delete from a Hub. + hub_content_version (str): The version of the content that you want to delete from + a Hub. + hub_content_type (str): The type of the content that you want to delete from a Hub. + hub_name (str): The name of the Hub that you want to delete content in. + """ + request = { + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + "HubName": hub_name, + "HubContentVersion": hub_content_version, + } + + return self.sagemaker_client.delete_hub_content(**request) + def get_model_package_args( content_types=None, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index de543b6f53..3bffa35282 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -6529,3 +6529,167 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): Filename="./foo/bar/mode.tar.gz", ExtraArgs=None, ) + + +def test_create_hub(sagemaker_session): + sagemaker_session.create_hub( + hub_name="mock-hub-name", + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + + request = { + "HubName": "mock-hub-name", + "HubDescription": "this is my sagemaker hub", + "HubDisplayName": "Mock Hub", + "HubSearchKeywords": ["mock", "hub", "123"], + "S3StorageConfig": {"S3OutputPath": "s3://my-hub-bucket/"}, + "Tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}], + } + + sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + + +def test_describe_hub(sagemaker_session): + sagemaker_session.describe_hub( + hub_name="mock-hub-name", + ) + + request = { + "HubName": "mock-hub-name", + } + + sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + + +def test_list_hubs(sagemaker_session): + sagemaker_session.list_hubs( + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08-2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08-2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + + +def test_list_hub_contents(sagemaker_session): + sagemaker_session.list_hub_contents( + hub_name="mock-hub-123", + hub_content_type="MODEL", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODEL", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + + +def test_delete_hub(sagemaker_session): + sagemaker_session.delete_hub( + hub_name="mock-hub-123", + ) + + request = { + "HubName": "mock-hub-123", + } + + sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + + +def test_import_hub_content(sagemaker_session): + sagemaker_session.import_hub_content( + hub_name="mock-hub-123", + hub_content_type="MODEL", + document_schema_version="1.0.0", + hub_content_document="{'training_script_location':'s3://path/to/script.py'}", + hub_content_name="mock-hub-content-1", + hub_content_display_name="Mock Hub Content One", + hub_content_description="This is my special Hub Content for my special Hub", + hub_content_version="5.5.5", + hub_content_markdown="markdown", + hub_content_search_keywords=["Hub", "Machine Learning", "Content"], + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODEL", + "DocumentSchemaVersion": "1.0.0", + "HubContentDocument": "{'training_script_location':'s3://path/to/script.py'}", + "HubContentName": "mock-hub-content-1", + "HubContentDisplayName": "Mock Hub Content One", + "HubContentDescription": "This is my special Hub Content for my special Hub", + "HubContentVersion": "5.5.5", + "HubContentMarkdown": "markdown", + "HubContentSearchKeywords": ["Hub", "Machine Learning", "Content"], + } + + sagemaker_session.sagemaker_client.import_hub_content.assert_called_with(**request) + + +def test_describe_hub_content(sagemaker_session): + sagemaker_session.describe_hub_content( + hub_name="mock-hub-123", + hub_content_type="MODEL", + hub_content_name="mock-hub-content-1", + hub_content_version="5.5.5", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODEL", + "HubContentName": "mock-hub-content-1", + "HubContentVersion": "5.5.5", + } + + sagemaker_session.sagemaker_client.describe_hub_content.assert_called_with(**request) + + +def test_delete_hub_content(sagemaker_session): + sagemaker_session.delete_hub_content( + hub_name="mock-hub-123", + hub_content_type="MODEL", + hub_content_name="mock-hub-content-1", + hub_content_version="5.5.5", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODEL", + "HubContentName": "mock-hub-content-1", + "HubContentVersion": "5.5.5", + } + + sagemaker_session.sagemaker_client.delete_hub_content.assert_called_with(**request)