@@ -1151,6 +1151,149 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
1151
1151
)
1152
1152
self .model_subscription_link = json_obj .get ("model_subscription_link" )
1153
1153
1154
+ def from_describe_hub_content_response (self , response : DescribeHubContentResponse ) -> None :
1155
+ """Sets fields in object based on values in HubContentDocument
1156
+
1157
+ Args:
1158
+ hub_content_doc (Dict[str, any]): parsed HubContentDocument returned
1159
+ from SageMaker:DescribeHubContent
1160
+ """
1161
+ self .model_id : str = response .hub_content_name
1162
+ self .version : str = response .hub_content_version
1163
+ hub_content_document : HubModelDocument = response .hub_content_document
1164
+ self .url : str = hub_content_document .url
1165
+ self .min_sdk_version : str = hub_content_document .min_sdk_version
1166
+ self .training_supported : bool = hub_content_document .training_supported
1167
+ self .incremental_training_supported : bool = bool (
1168
+ hub_content_document ["IncrementalTrainingSupported" ]
1169
+ )
1170
+ self .hosting_ecr_uri : Optional [str ] = hub_content_document .hosting_ecr_uri
1171
+ self ._non_serializable_slots .append ("hosting_ecr_specs" )
1172
+
1173
+ hosting_artifact_bucket , hosting_artifact_key = parse_s3_url (
1174
+ hub_content_document .hosting_artifact_uri
1175
+ )
1176
+ self .hosting_artifact_key : str = hosting_artifact_key
1177
+ hosting_script_bucket , hosting_script_key = parse_s3_url (
1178
+ hub_content_document .hosting_script_uri
1179
+ )
1180
+ self .hosting_script_key : str = hosting_script_key
1181
+ self .inference_environment_variables = hub_content_document .inference_environment_variables
1182
+ self .inference_vulnerable : bool = False
1183
+ self .inference_dependencies : List [str ] = hub_content_document .inference_dependencies
1184
+ self .inference_vulnerabilities : List [str ] = []
1185
+ self .training_vulnerable : bool = False
1186
+ self .training_dependencies : List [str ] = hub_content_document .training_dependencies
1187
+ self .training_vulnerabilities : List [str ] = []
1188
+ self .deprecated : bool = False
1189
+ self .deprecated_message : Optional [str ] = None
1190
+ self .deprecate_warn_message : Optional [str ] = None
1191
+ self .usage_info_message : Optional [str ] = None
1192
+ self .default_inference_instance_type : Optional [
1193
+ str
1194
+ ] = hub_content_document .default_inference_instance_type
1195
+ self .default_training_instance_type : Optional [
1196
+ str
1197
+ ] = hub_content_document .default_training_instance_type
1198
+ self .supported_inference_instance_types : Optional [
1199
+ List [str ]
1200
+ ] = hub_content_document .supported_inference_instance_types
1201
+ self .supported_training_instance_types : Optional [
1202
+ List [str ]
1203
+ ] = hub_content_document .supported_training_instance_types
1204
+ self .dynamic_container_deployment_supported : Optional [
1205
+ bool
1206
+ ] = hub_content_document .dynamic_container_deployment_supported
1207
+ self .hosting_resource_requirements : Optional [
1208
+ Dict [str , int ]
1209
+ ] = hub_content_document .hosting_resource_requirements
1210
+ self .metrics : Optional [List [Dict [str , str ]]] = hub_content_document .training_metrics
1211
+ self .training_prepacked_script_key : Optional [str ] = None
1212
+ if hub_content_document .training_prepacked_script_uri is not None :
1213
+ training_prepacked_script_bucket , training_prepacked_script_key = parse_s3_url (
1214
+ hub_content_document .training_prepacked_script_uri
1215
+ )
1216
+ self .training_prepacked_script_key = training_prepacked_script_key
1217
+
1218
+ self .hosting_prepacked_artifact_key : Optional [str ] = None
1219
+ if hub_content_document .hosting_prepacked_artifact_uri is not None :
1220
+ hosting_prepacked_artifact_bucket , hosting_prepacked_artifact_key = parse_s3_url (
1221
+ hub_content_document .hosting_prepacked_artifact_uri
1222
+ )
1223
+ self .hosting_prepacked_artifact_key = hosting_prepacked_artifact_key
1224
+
1225
+ self .fit_kwargs = get_model_spec_kwargs_from_hub_content_document (
1226
+ ModelSpecKwargType .FIT , hub_content_document
1227
+ )
1228
+ self .model_kwargs = get_model_spec_kwargs_from_hub_content_document (
1229
+ ModelSpecKwargType .MODEL , hub_content_document
1230
+ )
1231
+ self .deploy_kwargs = get_model_spec_kwargs_from_hub_content_document (
1232
+ ModelSpecKwargType .DEPLOY , hub_content_document
1233
+ )
1234
+ self .estimator_kwargs = get_model_spec_kwargs_from_hub_content_document (
1235
+ ModelSpecKwargType .ESTIMATOR , hub_content_document
1236
+ )
1237
+
1238
+ self .predictor_specs : Optional [
1239
+ JumpStartPredictorSpecs
1240
+ ] = hub_content_document .sage_maker_sdk_predictor_specifications
1241
+ self .default_payloads : Optional [
1242
+ Dict [str , JumpStartSerializablePayload ]
1243
+ ] = hub_content_document .default_payloads
1244
+ self .gated_bucket = hub_content_document .gated_bucket
1245
+ self .inference_volume_size : Optional [int ] = hub_content_document .inference_volume_size
1246
+ self .inference_enable_network_isolation : bool = (
1247
+ hub_content_document .inference_enable_network_isolation
1248
+ )
1249
+ self .resource_name_base : Optional [str ] = hub_content_document .resource_name_base
1250
+
1251
+ self .hosting_eula_key : Optional [str ] = None
1252
+ if hub_content_document .hosting_eula_uri is not None :
1253
+ hosting_eula_bucket , hosting_eula_key = parse_s3_url (
1254
+ hub_content_document .hosting_eula_uri
1255
+ )
1256
+ self .hosting_eula_key = hosting_eula_key
1257
+
1258
+ self .hosting_model_package_arns : Optional [Dict ] = None # TODO: Missing from shcema?
1259
+ self .hosting_use_script_uri : bool = hub_content_document .hosting_use_script_uri
1260
+
1261
+ self .hosting_instance_type_variants : Optional [JumpStartInstanceTypeVariants ] = (
1262
+ JumpStartInstanceTypeVariants (hub_content_document .hosting_instance_type_variants )
1263
+ if hub_content_document .hosting_instance_type_variants
1264
+ else None
1265
+ )
1266
+
1267
+ if self .training_supported :
1268
+ self .training_ecr_uri : Optional [str ] = hub_content_document .training_ecr_uri
1269
+ self ._non_serializable_slots .append ("training_ecr_specs" )
1270
+ training_artifact_bucket , training_artifact_key = parse_s3_url (
1271
+ hub_content_document .training_artifact_uri
1272
+ )
1273
+ self .training_artifact_key : str = training_artifact_key
1274
+ training_script_bucket , training_script_key = parse_s3_url (
1275
+ hub_content_document .training_script_uri
1276
+ )
1277
+ self .training_script_key : str = training_script_key
1278
+
1279
+ self .hyperparameters : List [
1280
+ JumpStartHyperparameter
1281
+ ] = hub_content_document .hyperparameters
1282
+ self .training_volume_size : Optional [int ] = hub_content_document .training_volume_size
1283
+ self .training_enable_network_isolation : bool = (
1284
+ hub_content_document .training_enable_network_isolation
1285
+ )
1286
+ self .training_model_package_artifact_uris : Optional [
1287
+ Dict
1288
+ ] = hub_content_document .training_model_package_artifact_uri
1289
+ self .training_instance_type_variants : Optional [
1290
+ JumpStartInstanceTypeVariants
1291
+ ] = JumpStartInstanceTypeVariants (
1292
+ hub_content_document .training_instance_type_variants
1293
+ if hub_content_document .training_instance_type_variants
1294
+ else None
1295
+ )
1296
+
1154
1297
def supports_prepacked_inference (self ) -> bool :
1155
1298
"""Returns True if the model has a prepacked inference artifact."""
1156
1299
return getattr (self , "hosting_prepacked_artifact_key" , None ) is not None
0 commit comments