-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: Marketplace model support in HubService #4916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0092ff4
8b0ec90
273449c
a8a2453
d4430e2
cd82335
0b73463
3ca1deb
4bdd822
79a1163
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,12 +19,11 @@ | |
|
||
|
||
def camel_to_snake(camel_case_string: str) -> str: | ||
"""Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" | ||
snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) | ||
if "-" in snake_case_string: | ||
# remove any hyphen from the string for accurate conversion. | ||
snake_case_string = snake_case_string.replace("-", "") | ||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() | ||
"""Converts PascalCase to snake_case_string using a regex. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent regex and method name. Careful with renaming the method though, it would be backward incompatible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True, let me rewrite that docstring. This regex can handle camelCase as well as PascalCase |
||
|
||
This regex cannot handle whitespace ("PascalString TwoWords") | ||
""" | ||
return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_case_string).lower() | ||
|
||
|
||
def snake_to_upper_camel(snake_case_string: str) -> str: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -856,7 +856,16 @@ def validate_model_id_and_get_type( | |
if not isinstance(model_id, str): | ||
return None | ||
if hub_arn: | ||
return None | ||
model_types = _validate_hub_service_model_id_and_get_type( | ||
model_id=model_id, | ||
hub_arn=hub_arn, | ||
region=region, | ||
model_version=model_version, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
return ( | ||
model_types[0] if model_types else None | ||
) # Currently this function only supports one model type | ||
|
||
s3_client = sagemaker_session.s3_client if sagemaker_session else None | ||
region = region or constants.JUMPSTART_DEFAULT_REGION_NAME | ||
|
@@ -881,6 +890,37 @@ def validate_model_id_and_get_type( | |
return None | ||
|
||
|
||
def _validate_hub_service_model_id_and_get_type( | ||
model_id: Optional[str], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just curious, what happens if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah that should be just |
||
hub_arn: str, | ||
region: Optional[str] = None, | ||
model_version: Optional[str] = None, | ||
sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, | ||
) -> List[enums.JumpStartModelType]: | ||
"""Returns a list of JumpStartModelType based off the HubContent. | ||
|
||
Only returns valid JumpStartModelType. Returns an empty array if none are found. | ||
""" | ||
hub_model_specs = accessors.JumpStartModelsAccessor.get_model_specs( | ||
region=region, | ||
model_id=model_id, | ||
version=model_version, | ||
hub_arn=hub_arn, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
hub_content_model_types = [] | ||
model_types_field: Optional[List[str]] = getattr(hub_model_specs, "model_types", []) | ||
model_types = model_types_field if model_types_field else [] | ||
for model_type in model_types: | ||
try: | ||
hub_content_model_types.append(enums.JumpStartModelType[model_type]) | ||
except ValueError: | ||
continue | ||
|
||
return hub_content_model_types | ||
|
||
|
||
def _extract_value_from_list_of_tags( | ||
tag_keys: List[str], | ||
list_tags_result: List[str], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,23 +53,18 @@ def get_sm_session() -> Session: | |
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)) | ||
|
||
|
||
# def get_sm_session_with_override() -> Session: | ||
# # [TODO]: Remove service endpoint override before GA | ||
# # boto3.set_stream_logger(name='botocore', level=logging.DEBUG) | ||
# boto_session = boto3.Session(region_name="us-west-2") | ||
# sagemaker = boto3.client( | ||
# service_name="sagemaker-internal", | ||
# endpoint_url="https://sagemaker.beta.us-west-2.ml-platform.aws.a2z.com", | ||
# ) | ||
# sagemaker_runtime = boto3.client( | ||
# service_name="runtime.maeve", | ||
# endpoint_url="https://maeveruntime.beta.us-west-2.ml-platform.aws.a2z.com", | ||
# ) | ||
# return Session( | ||
# boto_session=boto_session, | ||
# sagemaker_client=sagemaker, | ||
# sagemaker_runtime_client=sagemaker_runtime, | ||
# ) | ||
def get_sm_session_with_override() -> Session: | ||
# [TODO]: Remove service endpoint override before GA | ||
# boto3.set_stream_logger(name='botocore', level=logging.DEBUG) | ||
boto_session = boto3.Session(region_name="us-west-2") | ||
sagemaker = boto3.client( | ||
service_name="sagemaker", | ||
endpoint_url="https://sagemaker.gamma.us-west-2.ml-platform.aws.a2z.com", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you merge this to Btw this code should never have been pushed to the public repo... |
||
) | ||
return Session( | ||
boto_session=boto_session, | ||
sagemaker_client=sagemaker, | ||
) | ||
|
||
|
||
def get_training_dataset_for_model_and_version(model_id: str, version: str) -> dict: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sanity-check: will the pySDK still understand previous schema?