Skip to content

Commit aa46a44

Browse files
malav-shastri (Malav Shastri)
and
Malav Shastri
committed
feat: implement curated hub parser and bug bash fixes (aws#1457)
* implement HubContentDocument parser * modify the parser to remove aliases for hubcontent documents * bug fix * update boto3 * Bug Fix in the parser * Improve Hub Class and related functionalities * Bug Fix and parser updates * add missing hub_arn support * Add model reference deployment support and other minor bug fixes * fix: retrieve correct image_uri (parser update) * fix: retrieve correct model URI and model data path from HubContentDocument (parser update) * Add model reference deployment support * Model accessor and cache retrieval bug fixes * fix: curated hub model training workflow * fix: pass sagemaker session object to retrieve model specs from describe_hub_content call * fix: fix payload retrieval for curated hub models * modify constants, enums * fix: update parser * Address nits in the parser * Add unit tests for parser * implement pagination for list_models utility * feat: support wildcard chars for model versions * Address nits and comments * Add Hub Content Arn Tag to training and hosting * Add Hub Content Arn Tag to training and hosting * fix: HubContentDocument schema version * fix broken unit tests * fix prepare_container_def unit tests to include ModelReferenceArn * fix unit tests for test_session.py * revert boto version changes * Fix unit tests * support wildcard model versions for training workflow * Add test cases for get_model_versions * Add/fix unit tests --------- Co-authored-by: Malav Shastri <[email protected]>
1 parent fa53c33 commit aa46a44

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+4322
-729
lines changed

src/sagemaker/accept_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,5 @@ def retrieve_default(
123123
region=region,
124124
tolerate_vulnerable_model=tolerate_vulnerable_model,
125125
tolerate_deprecated_model=tolerate_deprecated_model,
126+
sagemaker_session=sagemaker_session
126127
)

src/sagemaker/chainer/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def prepare_container_def(
282282
accelerator_type=None,
283283
serverless_inference_config=None,
284284
accept_eula=None,
285+
model_reference_arn=None
285286
):
286287
"""Return a container definition with framework configuration set in model environment.
287288
@@ -333,6 +334,7 @@ def prepare_container_def(
333334
self.model_data,
334335
deploy_env,
335336
accept_eula=accept_eula,
337+
model_reference_arn=model_reference_arn
336338
)
337339

338340
def serving_image_uri(

src/sagemaker/djl_inference/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ def prepare_container_def(
732732
accelerator_type=None,
733733
serverless_inference_config=None,
734734
accept_eula=None,
735+
model_reference_arn=None
735736
): # pylint: disable=unused-argument
736737
"""A container definition with framework configuration set in model environment variables.
737738

src/sagemaker/huggingface/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ def prepare_container_def(
479479
serverless_inference_config=None,
480480
inference_tool=None,
481481
accept_eula=None,
482+
model_reference_arn=None
482483
):
483484
"""A container definition with framework configuration set in model environment variables.
484485
@@ -533,6 +534,7 @@ def prepare_container_def(
533534
self.repacked_model_data or self.model_data,
534535
deploy_env,
535536
accept_eula=accept_eula,
537+
model_reference_arn=model_reference_arn
536538
)
537539

538540
def serving_image_uri(

src/sagemaker/jumpstart/accessors.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from sagemaker.jumpstart import cache
2323
from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs
2424
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
25+
from sagemaker.session import Session
26+
from sagemaker.jumpstart import constants
2527

2628

2729
class SageMakerSettings(object):
@@ -257,6 +259,7 @@ def get_model_specs(
257259
hub_arn: Optional[str] = None,
258260
s3_client: Optional[boto3.client] = None,
259261
model_type=JumpStartModelType.OPEN_WEIGHTS,
262+
sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
260263
) -> JumpStartModelSpecs:
261264
"""Returns model specs from JumpStart models cache.
262265
@@ -272,6 +275,9 @@ def get_model_specs(
272275
if s3_client is not None:
273276
additional_kwargs.update({"s3_client": s3_client})
274277

278+
if hub_arn:
279+
additional_kwargs.update({"sagemaker_session": sagemaker_session})
280+
275281
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
276282
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
277283
)
@@ -282,12 +288,16 @@ def get_model_specs(
282288
hub_model_arn = construct_hub_model_arn_from_inputs(
283289
hub_arn=hub_arn, model_name=model_id, version=version
284290
)
285-
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
291+
model_specs = JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)
292+
model_specs.set_hub_content_type(HubContentType.MODEL)
293+
return model_specs
286294
except:
287295
hub_model_arn = construct_hub_model_reference_arn_from_inputs(
288296
hub_arn=hub_arn, model_name=model_id, version=version
289297
)
290-
return JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_arn=hub_model_arn)
298+
model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference(hub_model_reference_arn=hub_model_arn)
299+
model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE)
300+
return model_specs
291301

292302
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
293303
model_id=model_id, version_str=version, model_type=model_type

src/sagemaker/jumpstart/artifacts/image_uris.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,11 @@ def _retrieve_image_uri(
130130
)
131131
if image_uri is not None:
132132
return image_uri
133-
ecr_specs = model_specs.hosting_ecr_specs
133+
if hub_arn:
134+
ecr_uri = model_specs.hosting_ecr_uri
135+
return ecr_uri
136+
else:
137+
ecr_specs = model_specs.hosting_ecr_specs
134138
if ecr_specs is None:
135139
raise ValueError(
136140
f"No inference ECR configuration found for JumpStart model ID '{model_id}' "
@@ -145,7 +149,11 @@ def _retrieve_image_uri(
145149
)
146150
if image_uri is not None:
147151
return image_uri
148-
ecr_specs = model_specs.training_ecr_specs
152+
if hub_arn:
153+
ecr_uri = model_specs.training_ecr_uri
154+
return ecr_uri
155+
else:
156+
ecr_specs = model_specs.training_ecr_specs
149157
if ecr_specs is None:
150158
raise ValueError(
151159
f"No training ECR configuration found for JumpStart model ID '{model_id}' "
@@ -198,6 +206,7 @@ def _retrieve_image_uri(
198206
version=version_override or ecr_specs.framework_version,
199207
py_version=ecr_specs.py_version,
200208
instance_type=instance_type,
209+
hub_arn=hub_arn,
201210
accelerator_type=accelerator_type,
202211
image_scope=image_scope,
203212
container_version=container_version,

src/sagemaker/jumpstart/artifacts/model_uris.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,15 @@ def _retrieve_model_uri(
153153

154154
is_prepacked = not model_specs.use_inference_script_uri()
155155

156-
model_artifact_key = (
157-
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
158-
if is_prepacked
159-
else _retrieve_hosting_artifact_key(model_specs, instance_type)
160-
)
156+
if hub_arn:
157+
model_artifact_uri = model_specs.hosting_artifact_uri
158+
return model_artifact_uri
159+
else:
160+
model_artifact_key = (
161+
_retrieve_hosting_prepacked_artifact_key(model_specs, instance_type)
162+
if is_prepacked
163+
else _retrieve_hosting_artifact_key(model_specs, instance_type)
164+
)
161165

162166
elif model_scope == JumpStartScriptScope.TRAINING:
163167

src/sagemaker/jumpstart/cache.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from packaging.version import Version
2323
from packaging.specifiers import SpecifierSet, InvalidSpecifier
2424
from sagemaker.jumpstart.constants import (
25+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2526
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
2627
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2728
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
@@ -48,17 +49,20 @@
4849
JumpStartModelSpecs,
4950
JumpStartS3FileType,
5051
JumpStartVersionedModelId,
51-
HubType,
5252
HubContentType
5353
)
5454
from sagemaker.jumpstart.hub import utils as hub_utils
5555
from sagemaker.jumpstart.hub.interfaces import (
5656
DescribeHubResponse,
5757
DescribeHubContentResponse,
5858
)
59+
from sagemaker.jumpstart.hub.parsers import (
60+
make_model_specs_from_describe_hub_content_response,
61+
)
5962
from sagemaker.jumpstart.enums import JumpStartModelType
6063
from sagemaker.jumpstart import utils
6164
from sagemaker.utilities.cache import LRUCache
65+
from sagemaker.session import Session
6266

6367

6468
class JumpStartModelsCache:
@@ -84,6 +88,7 @@ def __init__(
8488
s3_bucket_name: Optional[str] = None,
8589
s3_client_config: Optional[botocore.config.Config] = None,
8690
s3_client: Optional[boto3.client] = None,
91+
sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
8792
) -> None:
8893
"""Initialize a ``JumpStartModelsCache`` instance.
8994
@@ -105,6 +110,8 @@ def __init__(
105110
s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache.
106111
Default: None (no config).
107112
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
113+
sagemaker_session: sagemaker session object to use.
114+
Default: session object from default region us-west-2.
108115
"""
109116

110117
self._region = region or utils.get_region_fallback(
@@ -146,6 +153,7 @@ def __init__(
146153
if s3_client_config
147154
else boto3.client("s3", region_name=self._region)
148155
)
156+
self._sagemaker_session = sagemaker_session
149157

150158
def set_region(self, region: str) -> None:
151159
"""Set region for cache. Clears cache after new region is set."""
@@ -453,22 +461,34 @@ def _retrieval_function(
453461
hub_notebook_description = DescribeHubContentResponse(response)
454462
return JumpStartCachedContentValue(formatted_content=hub_notebook_description)
455463

456-
if data_type in [HubContentType.MODEL, HubContentType.MODEL_REFERENCE]:
457-
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
464+
if data_type in {
465+
HubContentType.MODEL,
466+
HubContentType.MODEL_REFERENCE,
467+
}:
468+
469+
hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(
458470
id_info
459471
)
472+
473+
model_version: str = hub_utils.get_hub_model_version(
474+
hub_model_name=hub_arn_extracted_info.hub_content_name,
475+
hub_model_type=data_type.value,
476+
hub_name=hub_arn_extracted_info.hub_name,
477+
sagemaker_session=self._sagemaker_session,
478+
hub_model_version=hub_arn_extracted_info.hub_content_version
479+
)
480+
460481
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
461-
hub_name=hub_name,
462-
hub_content_name=model_name,
482+
hub_name=hub_arn_extracted_info.hub_name,
483+
hub_content_name=hub_arn_extracted_info.hub_content_name,
463484
hub_content_version=model_version,
464-
hub_content_type=data_type,
485+
hub_content_type=data_type.value,
465486
)
466487

467488
model_specs = make_model_specs_from_describe_hub_content_response(
468489
DescribeHubContentResponse(hub_model_description),
469490
)
470491

471-
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
472492
return JumpStartCachedContentValue(formatted_content=model_specs)
473493

474494
raise ValueError(self._file_type_error_msg(data_type))

src/sagemaker/jumpstart/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@
185185

186186
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
187187

188+
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
189+
188190
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
189191
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
190192

src/sagemaker/jumpstart/enums.py

+23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import absolute_import
1616

1717
from enum import Enum
18+
from typing import List
1819

1920

2021
class ModelFramework(str, Enum):
@@ -93,6 +94,7 @@ class JumpStartTag(str, Enum):
9394
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9495
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
9596

97+
HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn"
9698

9799
class SerializerType(str, Enum):
98100
"""Enum class for serializers associated with JumpStart models."""
@@ -124,6 +126,27 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType":
124126
"""Removes suffix from type and instantiates enum."""
125127
base_type, _, _ = mime_type_with_suffix.partition(";")
126128
return MIMEType(base_type)
129+
130+
class NamingConventionType(str, Enum):
131+
"""Enum class for naming conventions."""
132+
133+
SNAKE_CASE = "snake_case"
134+
UPPER_CAMEL_CASE = "upper_camel_case"
135+
DEFAULT = UPPER_CAMEL_CASE
136+
137+
138+
class ModelSpecKwargType(str, Enum):
139+
"""Enum class for types of kwargs for model hub content document and model specs."""
140+
141+
FIT = "fit_kwargs"
142+
MODEL = "model_kwargs"
143+
ESTIMATOR = "estimator_kwargs"
144+
DEPLOY = "deploy_kwargs"
145+
146+
@classmethod
147+
def arg_keys(cls) -> List[str]:
148+
"""Returns a list of kwargs keys that each type can have"""
149+
return [member.value for member in cls]
127150

128151

129152
class JumpStartConfigRankingName(str, Enum):

src/sagemaker/jumpstart/estimator.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.instance_group import InstanceGroup
2929
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
3030
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
31-
from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs
31+
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
3232
from sagemaker.jumpstart.enums import JumpStartScriptScope
3333
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3434

@@ -511,28 +511,29 @@ def __init__(
511511
ValueError: If the model ID is not recognized by JumpStart.
512512
"""
513513

514+
hub_arn = None
515+
if hub_name:
516+
hub_arn = generate_hub_arn_for_init_kwargs(
517+
hub_name=hub_name, region=region, session=sagemaker_session
518+
)
519+
514520
def _validate_model_id_and_get_type_hook():
515521
return validate_model_id_and_get_type(
516522
model_id=model_id,
517523
model_version=model_version,
518524
region=region or getattr(sagemaker_session, "boto_region_name", None),
519525
script=JumpStartScriptScope.TRAINING,
520526
sagemaker_session=sagemaker_session,
527+
hub_arn=hub_arn
521528
)
522-
529+
523530
self.model_type = _validate_model_id_and_get_type_hook()
524531
if not self.model_type:
525532
JumpStartModelsAccessor.reset_cache()
526533
self.model_type = _validate_model_id_and_get_type_hook()
527-
if not self.model_type:
534+
if not self.model_type and not hub_arn:
528535
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
529536

530-
hub_arn = None
531-
if hub_name:
532-
hub_arn = generate_hub_arn_for_init_kwargs(
533-
hub_name=hub_name, region=region, session=sagemaker_session
534-
)
535-
536537
estimator_init_kwargs = get_init_kwargs(
537538
model_id=model_id,
538539
model_version=model_version,
@@ -691,6 +692,7 @@ def attach(
691692
training_job_name: str,
692693
model_id: Optional[str] = None,
693694
model_version: Optional[str] = None,
695+
hub_arn: Optional[str] = None,
694696
sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
695697
model_channel_name: str = "model",
696698
) -> "JumpStartEstimator":
@@ -756,6 +758,7 @@ def attach(
756758
model_specs = verify_model_region_and_return_specs(
757759
model_id=model_id,
758760
version=model_version,
761+
hub_arn=hub_arn,
759762
region=sagemaker_session.boto_region_name,
760763
scope=JumpStartScriptScope.TRAINING,
761764
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
@@ -1110,6 +1113,7 @@ def deploy(
11101113
predictor=predictor,
11111114
model_id=self.model_id,
11121115
model_version=self.model_version,
1116+
hub_arn=self.hub_arn,
11131117
region=self.region,
11141118
tolerate_deprecated_model=self.tolerate_deprecated_model,
11151119
tolerate_vulnerable_model=self.tolerate_vulnerable_model,

0 commit comments

Comments
 (0)