Skip to content

Commit 7ec4828

Browse files
jinyoung-lim authored and bencrabtree committed
feature: JumpStart CuratedHub class creation and function definitions (aws#4448)
1 parent 83a3020 commit 7ec4828

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

src/sagemaker/jumpstart/types.py

+143
Original file line number | Diff line number | Diff line change
@@ -1151,6 +1151,149 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
11511151
)
11521152
self.model_subscription_link = json_obj.get("model_subscription_link")
11531153

1154+
def from_describe_hub_content_response(self, response: DescribeHubContentResponse) -> None:
    """Sets fields in object based on values in HubContentDocument

    Args:
        response (DescribeHubContentResponse): parsed response returned from
            SageMaker:DescribeHubContent. Its ``hub_content_name``,
            ``hub_content_version`` and ``hub_content_document`` are read.
    """
    self.model_id: str = response.hub_content_name
    self.version: str = response.hub_content_version
    hub_content_document: HubModelDocument = response.hub_content_document
    self.url: str = hub_content_document.url
    self.min_sdk_version: str = hub_content_document.min_sdk_version
    self.training_supported: bool = hub_content_document.training_supported
    # Read via attribute access like every other field of the document (the
    # previous subscript lookup with a PascalCase key was inconsistent).
    # NOTE(review): assumes HubModelDocument exposes the snake_case
    # `incremental_training_supported` attribute -- confirm against its definition.
    self.incremental_training_supported: bool = bool(
        hub_content_document.incremental_training_supported
    )
    self.hosting_ecr_uri: Optional[str] = hub_content_document.hosting_ecr_uri
    self._non_serializable_slots.append("hosting_ecr_specs")

    # Only the key half of each parsed S3 URL is stored; the bucket is discarded.
    _, hosting_artifact_key = parse_s3_url(hub_content_document.hosting_artifact_uri)
    self.hosting_artifact_key: str = hosting_artifact_key
    _, hosting_script_key = parse_s3_url(hub_content_document.hosting_script_uri)
    self.hosting_script_key: str = hosting_script_key
    self.inference_environment_variables = hub_content_document.inference_environment_variables
    self.inference_vulnerable: bool = False
    self.inference_dependencies: List[str] = hub_content_document.inference_dependencies
    self.inference_vulnerabilities: List[str] = []
    self.training_vulnerable: bool = False
    self.training_dependencies: List[str] = hub_content_document.training_dependencies
    self.training_vulnerabilities: List[str] = []
    self.deprecated: bool = False
    self.deprecated_message: Optional[str] = None
    self.deprecate_warn_message: Optional[str] = None
    self.usage_info_message: Optional[str] = None
    self.default_inference_instance_type: Optional[
        str
    ] = hub_content_document.default_inference_instance_type
    self.default_training_instance_type: Optional[
        str
    ] = hub_content_document.default_training_instance_type
    self.supported_inference_instance_types: Optional[
        List[str]
    ] = hub_content_document.supported_inference_instance_types
    self.supported_training_instance_types: Optional[
        List[str]
    ] = hub_content_document.supported_training_instance_types
    self.dynamic_container_deployment_supported: Optional[
        bool
    ] = hub_content_document.dynamic_container_deployment_supported
    self.hosting_resource_requirements: Optional[
        Dict[str, int]
    ] = hub_content_document.hosting_resource_requirements
    self.metrics: Optional[List[Dict[str, str]]] = hub_content_document.training_metrics

    self.training_prepacked_script_key: Optional[str] = None
    if hub_content_document.training_prepacked_script_uri is not None:
        _, training_prepacked_script_key = parse_s3_url(
            hub_content_document.training_prepacked_script_uri
        )
        self.training_prepacked_script_key = training_prepacked_script_key

    self.hosting_prepacked_artifact_key: Optional[str] = None
    if hub_content_document.hosting_prepacked_artifact_uri is not None:
        _, hosting_prepacked_artifact_key = parse_s3_url(
            hub_content_document.hosting_prepacked_artifact_uri
        )
        self.hosting_prepacked_artifact_key = hosting_prepacked_artifact_key

    self.fit_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.FIT, hub_content_document
    )
    self.model_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.MODEL, hub_content_document
    )
    self.deploy_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.DEPLOY, hub_content_document
    )
    self.estimator_kwargs = get_model_spec_kwargs_from_hub_content_document(
        ModelSpecKwargType.ESTIMATOR, hub_content_document
    )

    self.predictor_specs: Optional[
        JumpStartPredictorSpecs
    ] = hub_content_document.sage_maker_sdk_predictor_specifications
    self.default_payloads: Optional[
        Dict[str, JumpStartSerializablePayload]
    ] = hub_content_document.default_payloads
    self.gated_bucket = hub_content_document.gated_bucket
    self.inference_volume_size: Optional[int] = hub_content_document.inference_volume_size
    self.inference_enable_network_isolation: bool = (
        hub_content_document.inference_enable_network_isolation
    )
    self.resource_name_base: Optional[str] = hub_content_document.resource_name_base

    self.hosting_eula_key: Optional[str] = None
    if hub_content_document.hosting_eula_uri is not None:
        _, hosting_eula_key = parse_s3_url(hub_content_document.hosting_eula_uri)
        self.hosting_eula_key = hosting_eula_key

    self.hosting_model_package_arns: Optional[Dict] = None  # TODO: Missing from schema?
    self.hosting_use_script_uri: bool = hub_content_document.hosting_use_script_uri

    self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
        JumpStartInstanceTypeVariants(hub_content_document.hosting_instance_type_variants)
        if hub_content_document.hosting_instance_type_variants
        else None
    )

    # Training-only fields are populated solely for models that support training.
    if self.training_supported:
        self.training_ecr_uri: Optional[str] = hub_content_document.training_ecr_uri
        self._non_serializable_slots.append("training_ecr_specs")
        _, training_artifact_key = parse_s3_url(hub_content_document.training_artifact_uri)
        self.training_artifact_key: str = training_artifact_key
        _, training_script_key = parse_s3_url(hub_content_document.training_script_uri)
        self.training_script_key: str = training_script_key

        self.hyperparameters: List[
            JumpStartHyperparameter
        ] = hub_content_document.hyperparameters
        self.training_volume_size: Optional[int] = hub_content_document.training_volume_size
        self.training_enable_network_isolation: bool = (
            hub_content_document.training_enable_network_isolation
        )
        self.training_model_package_artifact_uris: Optional[
            Dict
        ] = hub_content_document.training_model_package_artifact_uri
        # Mirror the hosting variant handling above: the attribute is None when
        # the document carries no variants, rather than a wrapper built from None.
        self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = (
            JumpStartInstanceTypeVariants(hub_content_document.training_instance_type_variants)
            if hub_content_document.training_instance_type_variants
            else None
        )
1296+
11541297
def supports_prepacked_inference(self) -> bool:
    """Report whether a prepacked inference artifact key is set on this object.

    Returns:
        bool: True when the ``hosting_prepacked_artifact_key`` attribute exists
            and is not None, False otherwise.
    """
    prepacked_key = getattr(self, "hosting_prepacked_artifact_key", None)
    return prepacked_key is not None

tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py

+32
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,38 @@ def test_generate_hub_arn_for_init_kwargs():
147147
utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn
148148
)
149149

150+
assert (
151+
utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session)
152+
== hub_arn
153+
)
154+
155+
156+
def test_generate_default_hub_bucket_name():
    """The default hub bucket name is built from the session's region and account id."""
    session = Mock()
    session.boto_region_name = "us-east-1"
    session.account_id.return_value = "123456789123"

    bucket_name = utils.generate_default_hub_bucket_name(sagemaker_session=session)

    assert bucket_name == "sagemaker-hubs-us-east-1-123456789123"
165+
166+
167+
def test_create_hub_bucket_if_it_does_not_exist():
# Checks that utils.create_hub_bucket_if_it_does_not_exist, called with a mocked
# session (region "us-east-1", account "123456789123", S3 bucket creation_date
# set to None), returns "sagemaker-hubs-us-east-1-123456789123".
168+
mock_sagemaker_session = Mock()
169+
mock_sagemaker_session.account_id.return_value = "123456789123"
170+
mock_sagemaker_session.client("sts").get_caller_identity.return_value = {
171+
"Account": "123456789123"
172+
}
173+
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
174+
mock_sagemaker_session.boto_region_name = "us-east-1"
175+
bucket_name = "sagemaker-hubs-us-east-1-123456789123"
176+
created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist(
177+
sagemaker_session=mock_sagemaker_session
178+
)
179+
180+
# NOTE(review): the next line is missing a "." between `create_bucket` and
# `assert_called_once`. On a Mock, `create_bucketassert_called_once` is just an
# auto-created child attribute, so calling it asserts nothing. Presumably
# `.create_bucket.assert_called_once()` was intended -- TODO confirm that the
# function under test actually reaches `create_bucket` through this mock before
# changing it, otherwise the corrected assertion would fail.
mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once()
181+
assert created_hub_bucket_name == bucket_name
150182
assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn
151183

152184

tests/unit/sagemaker/jumpstart/test_cache.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
2929
JumpStartModelsCache,
3030
)
31+
from sagemaker.session_settings import SessionSettings
3132
from sagemaker.jumpstart.constants import (
3233
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
3334
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -1133,7 +1134,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11331134

11341135
mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root")
11351136
assert mocked_is_dir.call_count == 2
1136-
mocked_open.assert_not_called()
1137+
assert mocked_open.call_count == 2
11371138
mocked_get_json_file_and_etag_from_s3.assert_has_calls(
11381139
calls=[
11391140
call("models_manifest.json"),

0 commit comments

Comments
 (0)