|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 | 16 | import importlib.util
|
| 17 | +import json |
17 | 18 | import uuid
|
18 | 19 | from typing import Any, Type, List, Dict, Optional, Union
|
19 | 20 | from dataclasses import dataclass, field
|
|
24 | 25 | from pathlib import Path
|
25 | 26 |
|
26 | 27 | from sagemaker.enums import Tag
|
| 28 | +from sagemaker.jumpstart.accessors import JumpStartS3PayloadAccessor |
| 29 | +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket |
27 | 30 | from sagemaker.s3 import S3Downloader
|
28 | 31 |
|
29 | 32 | from sagemaker import Session
|
@@ -879,7 +882,7 @@ def build( # pylint: disable=R0911
|
879 | 882 |
|
880 | 883 | if isinstance(self.model, str):
|
881 | 884 | model_task = None
|
882 |
| - if self._is_jumpstart_model_id(): |
| 885 | + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): |
883 | 886 | self.model_hub = ModelHub.JUMPSTART
|
884 | 887 | return self._build_for_jumpstart()
|
885 | 888 | self.model_hub = ModelHub.HUGGINGFACE
|
@@ -1514,3 +1517,122 @@ def _optimize_prepare_for_hf(self):
|
1514 | 1517 | should_upload_artifacts=True,
|
1515 | 1518 | )
|
1516 | 1519 | self.pysdk_model.env.update(env)
|
| 1520 | + |
| 1521 | + def display_benchmark_metrics(self, **kwargs): |
| 1522 | + """Display Markdown Benchmark Metrics for deployment configs.""" |
| 1523 | + if not isinstance(self.model, str): |
| 1524 | + raise ValueError("Benchmarking is only supported for JumpStart or HuggingFace models") |
| 1525 | + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): |
| 1526 | + return super().display_benchmark_metrics(**kwargs) |
| 1527 | + else: |
| 1528 | + raise ValueError("This model does not have benchmark metrics yet") |
| 1529 | + |
| 1530 | + def get_deployment_config(self) -> Optional[Dict[str, Any]]: |
| 1531 | + """Gets the deployment config to apply to the model. |
| 1532 | +
|
| 1533 | + Returns: |
| 1534 | + Optional[Dict[str, Any]]: Deployment config to apply to this model. |
| 1535 | + """ |
| 1536 | + if not isinstance(self.model, str): |
| 1537 | + raise ValueError( |
| 1538 | + "Deployment config is only supported for JumpStart or HuggingFace models" |
| 1539 | + ) |
| 1540 | + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): |
| 1541 | + return super().get_deployment_config() |
| 1542 | + else: |
| 1543 | + raise ValueError("This model does not have any deployment config yet") |
| 1544 | + |
| 1545 | + def list_deployment_configs(self) -> List[Dict[str, Any]]: |
| 1546 | + """List deployment configs for the model in the current region. |
| 1547 | +
|
| 1548 | + Returns: |
| 1549 | + List[Dict[str, Any]]: A list of deployment configs. |
| 1550 | + """ |
| 1551 | + if not isinstance(self.model, str): |
| 1552 | + raise ValueError( |
| 1553 | + "Deployment config is only supported for JumpStart or HuggingFace models" |
| 1554 | + ) |
| 1555 | + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): |
| 1556 | + return super().list_deployment_configs() |
| 1557 | + else: |
| 1558 | + raise ValueError("This model does not have any deployment config yet") |
| 1559 | + |
| 1560 | + def set_deployment_config(self, config_name: str, instance_type: str) -> None: |
| 1561 | + """Sets the deployment config to apply to the model. |
| 1562 | +
|
| 1563 | + Args: |
| 1564 | + config_name (str): |
| 1565 | + The name of the deployment config to apply to the model. |
| 1566 | + Call list_deployment_configs to see the list of config names. |
| 1567 | + instance_type (str): |
| 1568 | + The instance_type that the model will use after setting |
| 1569 | + the config. |
| 1570 | + """ |
| 1571 | + if not isinstance(self.model, str): |
| 1572 | + raise ValueError( |
| 1573 | + "Deployment config is only supported for JumpStart or HuggingFace models" |
| 1574 | + ) |
| 1575 | + if self._is_jumpstart_model_id() or self._use_jumpstart_equivalent(): |
| 1576 | + logger.warning( |
| 1577 | + "If there are existing deployment configurations, " |
| 1578 | + "they will be overwritten by the config %s", |
| 1579 | + config_name, |
| 1580 | + ) |
| 1581 | + return super().set_deployment_config(config_name, instance_type) |
| 1582 | + else: |
| 1583 | + raise ValueError(f"The deployment config {config_name} cannot be set on this model") |
| 1584 | + |
| 1585 | + def _use_jumpstart_equivalent(self): |
| 1586 | + """Check if the HuggingFace model has a JumpStart equivalent. |
| 1587 | +
|
| 1588 | + Replace it with the equivalent if there's one |
| 1589 | + """ |
| 1590 | + if not hasattr(self, "_has_jumpstart_equivalent"): |
| 1591 | + self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping() |
| 1592 | + self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping |
| 1593 | + if self._has_jumpstart_equivalent: |
| 1594 | + huggingface_model_id = self.model |
| 1595 | + jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"] |
| 1596 | + self.model = jumpstart_model_id |
| 1597 | + merged_date = self._jumpstart_mapping[huggingface_model_id].get("merged-at") |
| 1598 | + self._build_for_jumpstart() |
| 1599 | + compare_model_diff_message = ( |
| 1600 | + "If you want to identify the differences between the two, " |
| 1601 | + "please use model_uris.retrieve() to retrieve the model " |
| 1602 | + "artifact S3 URI and compare them." |
| 1603 | + ) |
| 1604 | + logger.warning( # pylint: disable=logging-fstring-interpolation |
| 1605 | + "Please note that for this model we are using the JumpStart's" |
| 1606 | + f'local copy "{jumpstart_model_id}" ' |
| 1607 | + f'of the HuggingFace model "{huggingface_model_id}" you chose. ' |
| 1608 | + "We strive to keep our local copy synced with the HF model hub closely. " |
| 1609 | + "This model was synced " |
| 1610 | + f"{f'on {merged_date}' if merged_date else 'before 11/04/2024'}. " |
| 1611 | + f"{compare_model_diff_message if not self._is_gated_model() else ''}" |
| 1612 | + ) |
| 1613 | + return True |
| 1614 | + return False |
| 1615 | + |
| 1616 | + def _retrieve_hugging_face_model_mapping(self): |
| 1617 | + """Retrieve the HuggingFace/JumpStart model mapping and preprocess it.""" |
| 1618 | + converted_mapping = {} |
| 1619 | + region = self.sagemaker_session.boto_region_name |
| 1620 | + try: |
| 1621 | + mapping_json_object = JumpStartS3PayloadAccessor.get_object_cached( |
| 1622 | + bucket=get_jumpstart_content_bucket(region), |
| 1623 | + key="hf_model_id_map_cache.json", |
| 1624 | + region=region, |
| 1625 | + s3_client=self.sagemaker_session.s3_client, |
| 1626 | + ) |
| 1627 | + mapping = json.loads(mapping_json_object) |
| 1628 | + except Exception: # pylint: disable=broad-except |
| 1629 | + return converted_mapping |
| 1630 | + |
| 1631 | + for k, v in mapping.items(): |
| 1632 | + converted_mapping[v["hf-model-id"]] = { |
| 1633 | + "jumpstart-model-id": k, |
| 1634 | + "jumpstart-model-version": v["jumpstart-model-version"], |
| 1635 | + "merged-at": v.get("merged-at"), |
| 1636 | + "hf-model-repo-sha": v.get("hf-model-repo-sha"), |
| 1637 | + } |
| 1638 | + return converted_mapping |
0 commit comments