Skip to content

Commit 3aad76b

Browse files
committed
update session and add tests
1 parent 058d855 commit 3aad76b

File tree

2 files changed

+258
-56
lines changed

2 files changed

+258
-56
lines changed

src/sagemaker/session.py

+98-56
Original file line numberDiff line numberDiff line change
@@ -6485,14 +6485,20 @@ def create_hub(
64856485
Returns:
64866486
(dict): Return value from the ``CreateHub`` API.
64876487
"""
6488-
return self.sagemaker_client.create_hub(
6489-
hub_name=hub_name,
6490-
hub_description=hub_description,
6491-
hub_display_name=hub_display_name,
6492-
hub_search_keywords=hub_search_keywords,
6493-
s3_storage_config=s3_storage_config,
6494-
tags=tags
6495-
)
6488+
request = {
6489+
"HubName": hub_name,
6490+
"HubDescription": hub_description
6491+
}
6492+
if hub_display_name:
6493+
request["HubDisplayName"] = hub_display_name
6494+
if hub_search_keywords:
6495+
request["HubSearchKeywords"] = hub_search_keywords
6496+
if s3_storage_config:
6497+
request["S3StorageConfig"] = s3_storage_config
6498+
if tags:
6499+
request["Tags"] = tags
6500+
6501+
return self.sagemaker_client.create_hub(**request)
64966502

64976503
def describe_hub(
64986504
self,
@@ -6506,9 +6512,9 @@ def describe_hub(
65066512
Returns:
65076513
(dict): Return value for ``DescribeHub`` API
65086514
"""
6509-
return self.sagemaker_client.describe_hub(
6510-
hub_name=hub_name
6511-
)
6515+
request = { "HubName": hub_name }
6516+
6517+
return self.sagemaker_client.describe_hub(**request)
65126518

65136519
def list_hubs(
65146520
self,
@@ -6535,15 +6541,23 @@ def list_hubs(
65356541
Returns:
65366542
(dict): Return value for ``ListHubs`` API
65376543
"""
6538-
return self.sagemaker_client.list_hubs(
6539-
creation_time_after=creation_time_after,
6540-
creation_time_before=creation_time_before,
6541-
max_results=max_results,
6542-
max_schema_version=max_schema_version,
6543-
name_contains=name_contains,
6544-
sort_by=sort_by,
6545-
sort_order=sort_order
6546-
)
6544+
request = {}
6545+
if creation_time_after:
6546+
request["CreationTimeAfter"] = creation_time_after
6547+
if creation_time_before:
6548+
request["CreationTimeBefore"] = creation_time_before
6549+
if max_results:
6550+
request["MaxResults"] = max_results
6551+
if max_schema_version:
6552+
request["MaxSchemaVersion"] = max_schema_version
6553+
if name_contains:
6554+
request["NameContains"] = name_contains
6555+
if sort_by:
6556+
request["SortBy"] = sort_by
6557+
if sort_order:
6558+
request["SortOrder"] = sort_order
6559+
6560+
return self.sagemaker_client.list_hubs(*request)
65476561

65486562
def list_hub_contents(
65496563
self,
@@ -6574,17 +6588,26 @@ def list_hub_contents(
65746588
Returns:
65756589
(dict): Return value for ``ListHubContents`` API
65766590
"""
6577-
return self.sagemaker_client.list_hub_contents(
6578-
hub_name=hub_name,
6579-
hub_content_type=hub_content_type,
6580-
creation_time_after=creation_time_after,
6581-
creation_time_before=creation_time_before,
6582-
max_results=max_results,
6583-
max_schema_version=max_schema_version,
6584-
name_contains=name_contains,
6585-
sort_by=sort_by,
6586-
sort_order=sort_order
6587-
)
6591+
request = {
6592+
"HubName": hub_name,
6593+
"HubContentType": hub_content_type
6594+
}
6595+
if creation_time_after:
6596+
request["CreationTimeAfter"] = creation_time_after
6597+
if creation_time_before:
6598+
request["CreationTimeBefore"] = creation_time_before
6599+
if max_results:
6600+
request["MaxResults"] = max_results
6601+
if max_schema_version:
6602+
request["MaxSchemaVersion"] = max_schema_version
6603+
if name_contains:
6604+
request["NameContains"] = name_contains
6605+
if sort_by:
6606+
request["SortBy"] = sort_by
6607+
if sort_order:
6608+
request["SortOrder"] = sort_order
6609+
6610+
return self.sagemaker_client.list_hub_contents(*request)
65886611

65896612
def delete_hub(
65906613
self,
@@ -6595,14 +6618,17 @@ def delete_hub(
65956618
Args:
65966619
hub_name (str): The name of the hub to delete.
65976620
"""
6598-
return self.sagemaker_client.delete_hub(hub_name=hub_name)
6621+
request = { "HubName": hub_name }
6622+
6623+
return self.sagemaker_client.delete_hub(*request)
65996624

66006625
def import_hub_content(
66016626
self,
66026627
document_schema_version: str,
66036628
hub_content_name: str,
66046629
hub_content_type: str,
66056630
hub_name: str,
6631+
hub_content_document: str,
66066632
hub_content_display_name=None,
66076633
hub_content_description=None,
66086634
hub_content_version=None,
@@ -6618,6 +6644,8 @@ def import_hub_content(
66186644
hub_content_version (str): The version of the HubContent to import.
66196645
hub_content_type (str): The type of HubContent to import.
66206646
hub_name (str): The name of the Hub to import content to.
6647+
hub_content_document (str): The hub content document that describes information about the hub content
6648+
such as type, associated containers, scripts, and more.
66216649
hub_content_display_name (str): The display name of the HubContent to import.
66226650
hub_content_description (str): The description of the HubContent to import.
66236651
hub_content_markdown (str): A string that provides a description of the HubContent. This string can include links, tables,
@@ -6627,18 +6655,27 @@ def import_hub_content(
66276655
Returns:
66286656
(dict): Return value for ``ImportHubContent`` API
66296657
"""
6630-
return self.sagemaker_client.import_hub_content(
6631-
document_schema_version=document_schema_version,
6632-
hub_content_name=hub_content_name,
6633-
hub_content_version=hub_content_version,
6634-
hub_content_type=hub_content_type,
6635-
hub_name=hub_name,
6636-
hub_content_display_name=hub_content_display_name,
6637-
hub_content_description=hub_content_description,
6638-
hub_content_markdown=hub_content_markdown,
6639-
hub_content_search_keywords=hub_content_search_keywords,
6640-
tags=tags
6641-
)
6658+
request = {
6659+
"DocumentSchemaVersion": document_schema_version,
6660+
"HubContentName": hub_content_name,
6661+
"HubContentType": hub_content_type,
6662+
"HubName": hub_name,
6663+
"HubContentDocument": hub_content_document
6664+
}
6665+
if hub_content_display_name:
6666+
request["HubContentDisplayName"] = hub_content_display_name
6667+
if hub_content_description:
6668+
request["HubContentDescription"] = hub_content_description
6669+
if hub_content_version:
6670+
request["HubContentVersion"] = hub_content_version
6671+
if hub_content_markdown:
6672+
request["HubContentMarkdown"] = hub_content_markdown
6673+
if hub_content_search_keywords:
6674+
request["HubContentSearchKeywords"] = hub_content_search_keywords
6675+
if tags:
6676+
request["Tags"] = tags
6677+
6678+
return self.sagemaker_client.import_hub_content(*request)
66426679

66436680
def describe_hub_content(
66446681
self,
@@ -6658,12 +6695,15 @@ def describe_hub_content(
66586695
Returns:
66596696
(dict): Return value for ``DescribeHubContent`` API
66606697
"""
6661-
return self.sagemaker_client.describe_hub_content(
6662-
hub_content_name=hub_content_name,
6663-
hub_content_type=hub_content_type,
6664-
hub_name=hub_name,
6665-
hub_content_version=hub_content_version
6666-
)
6698+
request = {
6699+
"HubContentName": hub_content_name,
6700+
"HubContentType": hub_content_type,
6701+
"HubName": hub_name
6702+
}
6703+
if hub_content_version:
6704+
request["HubContentVersion"] = hub_content_version
6705+
6706+
return self.sagemaker_client.describe_hub_content(*request)
66676707

66686708
def delete_hub_content(
66696709
self,
@@ -6680,12 +6720,14 @@ def delete_hub_content(
66806720
hub_content_type (str): The type of the content that you want to delete from a Hub.
66816721
hub_name (str): The name of the Hub that you want to delete content in.
66826722
"""
6683-
return self.sagemaker_client.delete_hub_content(
6684-
hub_content_name=hub_content_name,
6685-
hub_content_version=hub_content_version,
6686-
hub_content_type=hub_content_type,
6687-
hub_name=hub_name
6688-
)
6723+
request = {
6724+
"HubContentName": hub_content_name,
6725+
"HubContentType": hub_content_type,
6726+
"HubName": hub_name,
6727+
"HubContentVersion": hub_content_version
6728+
}
6729+
6730+
return self.sagemaker_client.delete_hub_content(*request)
66896731

66906732

66916733
def get_model_package_args(

tests/unit/test_session.py

+160
Original file line numberDiff line numberDiff line change
@@ -6489,3 +6489,163 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session):
64896489
Filename="./foo/bar/mode.tar.gz",
64906490
ExtraArgs=None,
64916491
)
6492+
6493+
def test_create_hub(sagemaker_session):
6494+
sagemaker_session.create_hub(
6495+
hub_name="mock-hub-name",
6496+
hub_description="this is my sagemaker hub",
6497+
hub_display_name="Mock Hub",
6498+
hub_search_keywords=["mock", "hub", "123"],
6499+
s3_storage_config={
6500+
"S3OutputPath": "s3://my-hub-bucket/"
6501+
},
6502+
tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}]
6503+
)
6504+
6505+
request = {
6506+
"hub_name": "mock-hub-name",
6507+
"hub_description": "this is my sagemaker hub",
6508+
"hub_display_name": "Mock Hub",
6509+
"hub_search_keywords": ["mock", "hub", "123"],
6510+
"s3_storage_config": {
6511+
"S3OutputPath": "s3://my-hub-bucket/"
6512+
},
6513+
"tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}]
6514+
}
6515+
6516+
sagemaker_session.create_hub.assert_called_with(**request)
6517+
6518+
def test_describe_hub(sagemaker_session):
6519+
sagemaker_session.describe_hub(
6520+
hub_name="mock-hub-name",
6521+
)
6522+
6523+
request = {
6524+
"hub_name": "mock-hub-name",
6525+
}
6526+
6527+
sagemaker_session.describe_hub.assert_called_with(**request)
6528+
6529+
def test_list_hubs(sagemaker_session):
6530+
sagemaker_session.list_hubs(
6531+
creation_time_after="08-14-1997 12:00:00",
6532+
creation_time_before="01-08/2024 10:25:00",
6533+
max_results="25",
6534+
max_schema_version="1.0.5",
6535+
name_contains="mock-hub",
6536+
sort_by="HubName",
6537+
sort_order="Ascending"
6538+
)
6539+
6540+
request = {
6541+
"creation_time_after": "08-14-1997 12:00:00",
6542+
"creation_time_before": "01-08/2024 10:25:00",
6543+
"max_results": "25",
6544+
"max_schema_version": "1.0.5",
6545+
"name_contains": "mock-hub",
6546+
"sort_by": "HubName",
6547+
"sort_order": "Ascending"
6548+
}
6549+
6550+
sagemaker_session.list_hubs.assert_called_with(**request)
6551+
6552+
def test_list_hub_contents(sagemaker_session):
6553+
sagemaker_session.list_hub_contents(
6554+
hub_name="mock-hub-123",
6555+
hub_content_type="MODEL",
6556+
creation_time_after="08-14-1997 12:00:00",
6557+
creation_time_before="01-08/2024 10:25:00",
6558+
max_results="25",
6559+
max_schema_version="1.0.5",
6560+
name_contains="mock-hub",
6561+
sort_by="HubName",
6562+
sort_order="Ascending"
6563+
)
6564+
6565+
request = {
6566+
"hub_name": "mock-hub-123",
6567+
"hub_content_type": "MODEL",
6568+
"creation_time_after": "08-14-1997 12:00:00",
6569+
"creation_time_before": "01-08/2024 10:25:00",
6570+
"max_results": "25",
6571+
"max_schema_version": "1.0.5",
6572+
"name_contains": "mock-hub",
6573+
"sort_by": "HubName",
6574+
"sort_order": "Ascending"
6575+
}
6576+
6577+
sagemaker_session.list_hub_contents.assert_called_with(**request)
6578+
6579+
def test_delete_hub(sagemaker_session):
6580+
sagemaker_session.delete_hub(
6581+
hub_name="mock-hub-123",
6582+
)
6583+
6584+
request = {
6585+
"hub_name": "mock-hub-123",
6586+
}
6587+
6588+
sagemaker_session.delete_hub.assert_called_with(**request)
6589+
6590+
def test_import_hub_content(sagemaker_session):
6591+
sagemaker_session.import_hub_content(
6592+
hub_name="mock-hub-123",
6593+
hub_content_type="MODEL",
6594+
document_schema_version="1.0.0",
6595+
hub_content_document="{'training_script_location':'s3://path/to/script.py'}",
6596+
hub_content_name="mock-hub-content-1",
6597+
hub_content_display_name="Mock Hub Content One",
6598+
hub_content_description="This is my special Hub Content for my special Hub",
6599+
hub_content_version="5.5.5",
6600+
hub_content_markdown="markdown",
6601+
hub_content_search_keywords=["Hub","Machine Learning","Content"]
6602+
)
6603+
6604+
request = {
6605+
"hub_name": "mock-hub-123",
6606+
"hub_content_type": "MODEL",
6607+
"document_schema_version": "1.0.0",
6608+
"hub_content_document": "{'training_script_location':'s3://path/to/script.py'}",
6609+
"hub_content_name": "mock-hub-content-1",
6610+
"hub_content_display_name": "Mock Hub Content One",
6611+
"hub_content_description": "This is my special Hub Content for my special Hub",
6612+
"hub_content_version": "5.5.5",
6613+
"hub_content_markdown": "markdown",
6614+
"hub_content_search_keywords": ["Hub","Machine Learning","Content"]
6615+
}
6616+
6617+
sagemaker_session.import_hub_content.assert_called_with(**request)
6618+
6619+
def test_describe_hub_content(sagemaker_session):
6620+
sagemaker_session.describe_hub_content(
6621+
hub_name="mock-hub-123",
6622+
hub_content_type="MODEL",
6623+
hub_content_name="mock-hub-content-1",
6624+
hub_content_version="5.5.5",
6625+
)
6626+
6627+
request = {
6628+
"hub_name": "mock-hub-123",
6629+
"hub_content_type": "MODEL",
6630+
"hub_content_name": "mock-hub-content-1",
6631+
"hub_content_version": "5.5.5",
6632+
}
6633+
6634+
sagemaker_session.describe_hub_content.assert_called_with(**request)
6635+
6636+
def test_delete_hub_content(sagemaker_session):
6637+
sagemaker_session.delete_hub_content(
6638+
hub_name="mock-hub-123",
6639+
hub_content_type="MODEL",
6640+
hub_content_name="mock-hub-content-1",
6641+
hub_content_version="5.5.5",
6642+
)
6643+
6644+
request = {
6645+
"hub_name": "mock-hub-123",
6646+
"hub_content_type": "MODEL",
6647+
"hub_content_name": "mock-hub-content-1",
6648+
"hub_content_version": "5.5.5",
6649+
}
6650+
6651+
sagemaker_session.delete_hub_content.assert_called_with(**request)

0 commit comments

Comments
 (0)