Skip to content

feature: model optimization #4775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 48 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2c3d606
feat: Benchmark feature initial commit (#1463)
knikure Jun 6, 2024
73bf439
feat: Model class to support AdditionalModelDataSources (#1469)
makungaj1 Jun 10, 2024
c4529e3
feat: additional hosting model data source parsing (#1467)
Captainia Jun 10, 2024
0151209
Add optimize to ModelBuilder (#1468)
grenmester Jun 11, 2024
9a410e5
feat: Added utils for extracting JS data sources (#1471)
akozd Jun 11, 2024
2331dec
fix: update passing additional model data sources to API (#1472)
Captainia Jun 12, 2024
3c7b966
fix: overriding instance specific fields in config components (#1478)
Captainia Jun 13, 2024
f55e3c9
Feat: Add optimize to ModelBuilder JS (#1474)
makungaj1 Jun 14, 2024
c6581ff
feat: use Neo bucket in speculative decoding data source (#1479)
Captainia Jun 14, 2024
997e2ce
feat: add build/deploy support for fine-tuned JS models (#1473)
grenmester Jun 15, 2024
701b788
update: Add optimize to ModelBuilder JS (#1480)
makungaj1 Jun 17, 2024
f3b3504
update: Add optimize to ModelBuilder JS (#1485)
makungaj1 Jun 18, 2024
1f6f876
feat: add quicksilver telemetry (#1482)
grenmester Jun 18, 2024
b07f210
unit: tests for fine tuned JS model support (#1481)
grenmester Jun 18, 2024
262a5eb
fix: use current session and role when setting config (#1493)
Captainia Jun 20, 2024
99345d8
fix: training arn support (#1494)
grenmester Jun 21, 2024
9a3f6ca
Bug bash fixes (#1492)
makungaj1 Jun 21, 2024
114a716
Bug fixes (#1496)
makungaj1 Jun 24, 2024
0ac6014
JS Optimize api ref
Jun 24, 2024
80fb96a
Refactoring
Jun 25, 2024
c41a7ca
Refactoring
Jun 25, 2024
f2062a7
Fix issues
Jun 25, 2024
d3f4274
Channel name
Jun 25, 2024
3e93e95
Channel name
Jun 26, 2024
271a862
Optimization output
Jun 26, 2024
e3995b0
neuron model env
Jun 26, 2024
31d70e6
Merge master into master-benchmark-feature (#1502)
akozd Jul 1, 2024
f1bc99e
feat: Support Alt Configs for Public & Curated Hub (#1505)
akozd Jul 3, 2024
15e26c4
fix: make telemetry logger persist certain information (#1500)
grenmester Jul 3, 2024
26c8696
Optimize support for hf models (#1499)
makungaj1 Jul 3, 2024
6687c56
Fixing bugs (#1506)
makungaj1 Jul 5, 2024
7993b77
Fix public optimize api signature (#1507)
makungaj1 Jul 8, 2024
a263067
Merge branch 'master' into feature-release
makungaj1 Jul 8, 2024
d997612
Refactoring
Jul 8, 2024
7152db2
Integration tests
Jul 8, 2024
0e58562
Merge branch 'master' into feature-release
makungaj1 Jul 8, 2024
0bd6aa8
Skip Alt Config integ tests as metadata aren't fully deployed.
Jul 8, 2024
1e14343
Fix metric column name
Jul 9, 2024
9409031
Refactoring
Jul 9, 2024
ba3d49c
Display API
Jul 9, 2024
2edb2e6
Relax set deployment error handling
Jul 9, 2024
4dca186
Override region for draft model data source
Jul 9, 2024
4474119
use latest boto3
Jul 9, 2024
4d0d1d3
Merge branch 'master' into feature-release
makungaj1 Jul 9, 2024
a7d1bae
EBS Volue
Jul 9, 2024
661a415
model tags
Jul 9, 2024
59edbfb
UT
Jul 9, 2024
b85e2c3
FIX UT
Jul 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ accelerate>=0.24.1,<=0.27.0
schema==0.7.5
tensorflow>=2.1,<=2.16
mlflow>=2.12.2,<2.13
huggingface_hub>=0.23.4
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def read_requirements(filename):
# Declare minimal set for installation
required_packages = [
"attrs>=23.1.0,<24",
"boto3>=1.33.3,<2.0",
"boto3>=1.34.142,<2.0",
"cloudpickle==2.2.1",
"google-pasta",
"numpy>=1.9.0,<2.0",
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default accept type for the model matching the given arguments.

Expand All @@ -105,6 +106,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default accept type to use for the model.

Expand All @@ -125,4 +127,5 @@ def retrieve_default(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.

Expand All @@ -105,6 +106,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default content type to use for the model.

Expand All @@ -125,6 +127,7 @@ def retrieve_default(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.

Expand All @@ -125,6 +126,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseDeserializer: The default deserializer to use for the model.

Expand All @@ -146,4 +148,5 @@ def retrieve_default(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
9 changes: 9 additions & 0 deletions src/sagemaker/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,12 @@ class RoutingStrategy(Enum):
"""The endpoint routes requests to the specific instances that have
more capacity to process them.
"""


class Tag(str, Enum):
"""Enum class for tag keys to apply to models."""

OPTIMIZATION_JOB_NAME = "sagemaker-sdk:optimization-job-name"
SPECULATIVE_DRAFT_MODEL_PROVIDER = "sagemaker-sdk:speculative-draft-model-provider"
FINE_TUNING_MODEL_PATH = "sagemaker-sdk:fine-tuning-model-path"
FINE_TUNING_JOB_NAME = "sagemaker-sdk:fine-tuning-job-name"
3 changes: 3 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

Expand Down Expand Up @@ -68,6 +69,7 @@ def retrieve_default(
variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The variables to use for the model.

Expand All @@ -91,4 +93,5 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
config_name=config_name,
)
25 changes: 25 additions & 0 deletions src/sagemaker/huggingface/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
from __future__ import absolute_import

import os
from typing import Optional
import importlib.util

import urllib.request
from urllib.error import HTTPError, URLError
Expand Down Expand Up @@ -123,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] =
"Did not find model metadata for the following HuggingFace Model ID %s" % model_id
)
return hf_model_metadata_json


def download_huggingface_model_metadata(
model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None
) -> None:
"""Downloads the HuggingFace Model snapshot via HuggingFace API.

Args:
model_id (str): The HuggingFace Model ID
model_local_path (str): The local path to save the HuggingFace Model snapshot.
hf_hub_token (str): The HuggingFace Hub Token

Raises:
ImportError: If huggingface_hub is not installed.
"""
if not importlib.util.find_spec("huggingface_hub"):
raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed")

from huggingface_hub import snapshot_download

os.makedirs(model_local_path, exist_ok=True)
logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path)
snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token)
3 changes: 3 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -69,6 +70,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The hyperparameters to use for the model.

Expand All @@ -90,6 +92,7 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)


Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def retrieve(
inference_tool=None,
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -126,6 +127,7 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
Expand Down Expand Up @@ -166,6 +168,7 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.

Expand Down Expand Up @@ -67,6 +68,7 @@ def retrieve_default(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default instance type to use for the model.

Expand All @@ -92,6 +94,7 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
model_type=model_type,
config_name=config_name,
)


Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _retrieve_default_environment_variables(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.

Expand Down Expand Up @@ -71,6 +72,7 @@ def _retrieve_default_environment_variables(
environment variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the inference environment variables to use for the model.
"""
Expand All @@ -88,6 +90,7 @@ def _retrieve_default_environment_variables(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_environment_variables: Dict[str, str] = {}
Expand Down Expand Up @@ -126,6 +129,7 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
config_name=config_name,
)
)

Expand Down Expand Up @@ -173,6 +177,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.

Expand All @@ -198,6 +203,7 @@ def _retrieve_gated_model_uri_env_var_value(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
Expand All @@ -220,6 +226,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

s3_key: Optional[str] = (
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.

Expand Down Expand Up @@ -69,6 +70,7 @@ def _retrieve_default_hyperparameters(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the hyperparameters to use for the model.
"""
Expand All @@ -86,6 +88,7 @@ def _retrieve_default_hyperparameters(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_hyperparameters: Dict[str, str] = {}
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
):
"""Retrieves the container image URI for JumpStart models.

Expand Down Expand Up @@ -98,6 +99,7 @@ def _retrieve_image_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Expand All @@ -120,6 +122,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if image_scope == JumpStartScriptScope.INFERENCE:
Expand Down Expand Up @@ -213,4 +216,5 @@ def _retrieve_image_uri(
distribution=distribution,
base_framework_version=base_framework_version_override or base_framework_version,
training_compiler_config=training_compiler_config,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports incremental training.

Expand All @@ -57,6 +58,7 @@ def _model_supports_incremental_training(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
bool: the support status for incremental training.
"""
Expand All @@ -74,6 +76,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

return model_specs.supports_incremental_training()
Loading