Skip to content

Commit 81ecffa

Browse files
xiongz945pintaoz-aws
authored andcommitted
Feature: ModelBuilder supports HuggingFace Models with benchmark data and deployment configs (#1572)
1 parent 869b75f commit 81ecffa

File tree

2 files changed

+592
-18
lines changed

2 files changed

+592
-18
lines changed

src/sagemaker/serve/builder/model_builder.py

+123-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import importlib.util
17+
import json
1718
import uuid
1819
from typing import Any, Type, List, Dict, Optional, Union
1920
from dataclasses import dataclass, field
@@ -24,6 +25,8 @@
2425
from pathlib import Path
2526

2627
from sagemaker.enums import Tag
28+
from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor
29+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2730
from sagemaker.s3 import S3Downloader
2831

2932
from sagemaker import Session
@@ -879,7 +882,7 @@ def build( # pylint: disable=R0911
879882

880883
if isinstance(self.model, str):
881884
model_task = None
882-
if self._is_jumpstart_model_id():
885+
if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
883886
self.model_hub = ModelHub.JUMPSTART
884887
return self._build_for_jumpstart()
885888
self.model_hub = ModelHub.HUGGINGFACE
@@ -1514,3 +1517,122 @@ def _optimize_prepare_for_hf(self):
15141517
should_upload_artifacts=True,
15151518
)
15161519
self.pysdk_model.env.update(env)
1520+
1521+
def display_benchmark_metrics(self, **kwargs):
1522+
"""Display Markdown Benchmark Metrics for deployment configs."""
1523+
if not isinstance(self.model, str):
1524+
raise ValueError("Benchmarking is only supported for JumpStart or HuggingFace models")
1525+
if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
1526+
return super().display_benchmark_metrics(**kwargs)
1527+
else:
1528+
raise ValueError("This model does not have benchmark metrics yet")
1529+
1530+
def get_deployment_config(self) -> Optional[Dict[str, Any]]:
1531+
"""Gets the deployment config to apply to the model.
1532+
1533+
Returns:
1534+
Optional[Dict[str, Any]]: Deployment config to apply to this model.
1535+
"""
1536+
if not isinstance(self.model, str):
1537+
raise ValueError(
1538+
"Deployment config is only supported for JumpStart or HuggingFace models"
1539+
)
1540+
if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
1541+
return super().get_deployment_config()
1542+
else:
1543+
raise ValueError("This model does not have any deployment config yet")
1544+
1545+
def list_deployment_configs(self) -> List[Dict[str, Any]]:
1546+
"""List deployment configs for the model in the current region.
1547+
1548+
Returns:
1549+
List[Dict[str, Any]]: A list of deployment configs.
1550+
"""
1551+
if not isinstance(self.model, str):
1552+
raise ValueError(
1553+
"Deployment config is only supported for JumpStart or HuggingFace models"
1554+
)
1555+
if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
1556+
return super().list_deployment_configs()
1557+
else:
1558+
raise ValueError("This model does not have any deployment config yet")
1559+
1560+
def set_deployment_config(self, config_name: str, instance_type: str) -> None:
1561+
"""Sets the deployment config to apply to the model.
1562+
1563+
Args:
1564+
config_name (str):
1565+
The name of the deployment config to apply to the model.
1566+
Call list_deployment_configs to see the list of config names.
1567+
instance_type (str):
1568+
The instance_type that the model will use after setting
1569+
the config.
1570+
"""
1571+
if not isinstance(self.model, str):
1572+
raise ValueError(
1573+
"Deployment config is only supported for JumpStart or HuggingFace models"
1574+
)
1575+
if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent():
1576+
logger.warning(
1577+
"If there are existing deployment configurations, "
1578+
"they will be overwritten by the config %s",
1579+
config_name,
1580+
)
1581+
return super().set_deployment_config(config_name, instance_type)
1582+
else:
1583+
raise ValueError(f"The deployment config {config_name} cannot be set on this model")
1584+
1585+
def _use_jumpstart_equivalent(self):
1586+
"""Check if the HuggingFace model has a JumpStart equivalent.
1587+
1588+
Replace it with the equivalent if there's one
1589+
"""
1590+
if not hasattr(self, "_has_jumpstart_equivalent"):
1591+
self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping()
1592+
self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping
1593+
if self._has_jumpstart_equivalent:
1594+
huggingface_model_id = self.model
1595+
jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"]
1596+
self.model = jumpstart_model_id
1597+
merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at")
1598+
self._build_for_jumpstart()
1599+
compare_model_diff_message = (
1600+
"If you want to identify the differences between the two, "
1601+
"please use model_uris.retrieve() to retrieve the model "
1602+
"artifact S3 URI and compare them."
1603+
)
1604+
logger.warning( # pylint: disable=logging-fstring-interpolation
1605+
"Please note that for this model we are using the JumpStart's"
1606+
f'local copy "{jumpstart_model_id}" '
1607+
f'of the HuggingFace model "{huggingface_model_id}" you chose. '
1608+
"We strive to keep our local copy synced with the HF model hub closely. "
1609+
"This model was synced "
1610+
f"{f'on {merged_date}' if merged_date else 'before 11/04/2024'}. "
1611+
f"{compare_model_diff_message if not self._is_gated_model() else ''}"
1612+
)
1613+
return True
1614+
return False
1615+
1616+
def _retrieve_hugging_face_model_mapping(self):
1617+
"""Retrieve the HuggingFace/JumpStart model mapping and preprocess it."""
1618+
converted_mapping = {}
1619+
region = self.sagemaker_session.boto_region_name
1620+
try:
1621+
mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached(
1622+
bucket=get_jumpstart_content_bucket(region),
1623+
key="hf_model_id_map_cache.json",
1624+
region=region,
1625+
s3_client=self.sagemaker_session.s3_client,
1626+
)
1627+
mapping = json.loads(mapping_json_object)
1628+
except Exception: # pylint: disable=broad-except
1629+
return converted_mapping
1630+
1631+
for k, v in mapping.items():
1632+
converted_mapping[v["hf-model-id"]] = {
1633+
"jumpstart-model-id": k,
1634+
"jumpstart-model-version": v["jumpstart-model-version"],
1635+
"merged-at": v.get("merged-at"),
1636+
"hf-model-repo-sha": v.get("hf-model-repo-sha"),
1637+
}
1638+
return converted_mapping

0 commit comments

Comments
 (0)