Skip to content

feature: support training for JumpStart model references as part of Curated Hub Phase 2 #5070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,6 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
raise ValueError(
"File URIs are supported in local mode only. Please use a S3 URI instead."
)

config = _Job._load_config(inputs, estimator)

current_hyperparameters = estimator.hyperparameters()
Expand Down
30 changes: 30 additions & 0 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
attribute_names: Optional[List[Union[str, PipelineVariable]]] = None,
target_attribute_name: Optional[Union[str, PipelineVariable]] = None,
shuffle_config: Optional["ShuffleConfig"] = None,
hub_access_config: Optional[dict] = None,
model_access_config: Optional[dict] = None,
):
r"""Create a definition for input data used by an SageMaker training job.

Expand Down Expand Up @@ -102,6 +104,13 @@ def __init__(
shuffle_config (sagemaker.inputs.ShuffleConfig): If specified this configuration enables
shuffling on this channel. See the SageMaker API documentation for more info:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
hub_access_config (dict): Specify the HubAccessConfig of a
Model Reference for which a training job is being created.
model_access_config (dict): For models that require a Model Access Config, specify True
or False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
"""
self.config = {
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
Expand Down Expand Up @@ -129,6 +138,27 @@ def __init__(
self.config["TargetAttributeName"] = target_attribute_name
if shuffle_config is not None:
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}
self.add_hub_access_config(hub_access_config)
self.add_model_access_config(model_access_config)

def add_hub_access_config(self, hub_access_config=None):
    """Attach a HubAccessConfig to this channel's S3 data source.

    The value is stored under
    ``config["DataSource"]["S3DataSource"]["HubAccessConfig"]``. Passing
    ``None`` leaves the channel configuration untouched.

    Args:
        hub_access_config (dict): The HubAccessConfig to record on the
            channel's configuration (Default: None).
    """
    if hub_access_config is None:
        return
    s3_data_source = self.config["DataSource"]["S3DataSource"]
    s3_data_source["HubAccessConfig"] = hub_access_config

def add_model_access_config(self, model_access_config=None):
    """Attach a ModelAccessConfig to this channel's S3 data source.

    The value is stored under
    ``config["DataSource"]["S3DataSource"]["ModelAccessConfig"]``. Passing
    ``None`` leaves the channel configuration untouched.

    Args:
        model_access_config (dict): The ModelAccessConfig to record on the
            channel's configuration (Default: None).
    """
    if model_access_config is None:
        return
    s3_data_source = self.config["DataSource"]["S3DataSource"]
    s3_data_source["ModelAccessConfig"] = model_access_config


class ShuffleConfig(object):
Expand Down
55 changes: 46 additions & 9 deletions src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def stop(self):
@staticmethod
def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
"""Placeholder docstring"""
model_access_config, hub_access_config = _Job._get_access_configs(estimator)
input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
role = (
estimator.sagemaker_session.expand_role(estimator.role)
Expand Down Expand Up @@ -95,19 +96,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
validate_uri,
content_type="application/x-sagemaker-model",
input_mode="File",
model_access_config=model_access_config,
hub_access_config=hub_access_config,
)
if model_channel:
input_config = [] if input_config is None else input_config
input_config.append(model_channel)

if estimator.enable_network_isolation():
code_channel = _Job._prepare_channel(
input_config, estimator.code_uri, estimator.code_channel_name, validate_uri
)
code_channel = _Job._prepare_channel(
input_config,
estimator.code_uri,
estimator.code_channel_name,
validate_uri,
)

if code_channel:
input_config = [] if input_config is None else input_config
input_config.append(code_channel)
if code_channel:
input_config = [] if input_config is None else input_config
input_config.append(code_channel)

return {
"input_config": input_config,
Expand All @@ -118,6 +123,23 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
"vpc_config": vpc_config,
}

@staticmethod
def _get_access_configs(estimator):
    """Read the optional access configs off an estimator object.

    JumpStartEstimator uses access configs which need to be added to the
    model channel, so they are surfaced here for use at the job level. An
    estimator that lacks either attribute yields ``None`` for that entry.

    Args:
        estimator (EstimatorBase): estimator object with access config field if applicable

    Returns:
        tuple: ``(model_access_config, hub_access_config)``.
    """
    return (
        getattr(estimator, "model_access_config", None),
        getattr(estimator, "hub_access_config", None),
    )

@staticmethod
def _format_inputs_to_input_config(inputs, validate_uri=True):
"""Placeholder docstring"""
Expand Down Expand Up @@ -173,6 +195,8 @@ def _format_string_uri_input(
input_mode=None,
compression=None,
target_attribute_name=None,
model_access_config=None,
hub_access_config=None,
):
"""Placeholder docstring"""
s3_input_result = TrainingInput(
Expand All @@ -181,6 +205,8 @@ def _format_string_uri_input(
input_mode=input_mode,
compression=compression,
target_attribute_name=target_attribute_name,
model_access_config=model_access_config,
hub_access_config=hub_access_config,
)
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
return s3_input_result
Expand All @@ -193,7 +219,11 @@ def _format_string_uri_input(
)
if isinstance(uri_input, str):
return s3_input_result
if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)):
if isinstance(uri_input, (file_input, FileSystemInput)):
return uri_input
if isinstance(uri_input, TrainingInput):
uri_input.add_hub_access_config(hub_access_config=hub_access_config)
uri_input.add_model_access_config(model_access_config=model_access_config)
return uri_input
if is_pipeline_variable(uri_input):
return s3_input_result
Expand All @@ -211,6 +241,8 @@ def _prepare_channel(
validate_uri=True,
content_type=None,
input_mode=None,
model_access_config=None,
hub_access_config=None,
):
"""Placeholder docstring"""
if not channel_uri:
Expand All @@ -226,7 +258,12 @@ def _prepare_channel(
raise ValueError("Duplicate channel {} not allowed.".format(channel_name))

channel_input = _Job._format_string_uri_input(
channel_uri, validate_uri, content_type, input_mode
channel_uri,
validate_uri,
content_type,
input_mode,
model_access_config=model_access_config,
hub_access_config=hub_access_config,
)
channel = _Job._convert_input_to_channel(channel_name, channel_input)

Expand Down
7 changes: 4 additions & 3 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_region_fallback,
verify_model_region_and_return_specs,
)
from sagemaker.s3_utils import is_s3_url
from sagemaker.session import Session
from sagemaker.jumpstart.types import JumpStartModelSpecs

Expand Down Expand Up @@ -74,7 +75,7 @@ def _retrieve_hosting_artifact_key(model_specs: JumpStartModelSpecs, instance_ty
def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_type: str) -> str:
"""Returns instance specific training artifact key or default one as fallback."""
instance_specific_training_artifact_key: Optional[str] = (
model_specs.training_instance_type_variants.get_instance_specific_artifact_key(
model_specs.training_instance_type_variants.get_instance_specific_training_artifact_key(
instance_type=instance_type
)
if instance_type
Expand Down Expand Up @@ -185,8 +186,8 @@ def _retrieve_model_uri(
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
or default_jumpstart_bucket
)

model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
if not is_s3_url(model_artifact_key):
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"

return model_s3_uri

Expand Down
20 changes: 19 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
validate_model_id_and_get_type,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
remove_env_var_from_estimator_kwargs_if_accept_eula_present,
get_model_access_config,
get_hub_access_config,
)
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
Expand Down Expand Up @@ -619,6 +622,10 @@ def _validate_model_id_and_get_type_hook():
self._enable_network_isolation = estimator_init_kwargs.enable_network_isolation
self.config_name = estimator_init_kwargs.config_name
self.init_kwargs = estimator_init_kwargs.to_kwargs_dict(False)
# Access configs initialized to None, would be given a value when .fit() is called
# if applicable
self.model_access_config = None
self.hub_access_config = None

super(JumpStartEstimator, self).__init__(**estimator_init_kwargs.to_kwargs_dict())

Expand All @@ -629,6 +636,7 @@ def fit(
logs: Optional[str] = None,
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
accept_eula: Optional[bool] = None,
) -> None:
"""Start training job by calling base ``Estimator`` class ``fit`` method.

Expand Down Expand Up @@ -679,8 +687,16 @@ def fit(
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
(Default: None).
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
"""

self.model_access_config = get_model_access_config(accept_eula)
self.hub_access_config = get_hub_access_config(
hub_content_arn=self.init_kwargs.get("model_reference_arn", None)
)
estimator_fit_kwargs = get_fit_kwargs(
model_id=self.model_id,
model_version=self.model_version,
Expand All @@ -695,7 +711,9 @@ def fit(
tolerate_deprecated_model=self.tolerate_deprecated_model,
sagemaker_session=self.sagemaker_session,
config_name=self.config_name,
hub_access_config=self.hub_access_config,
)
remove_env_var_from_estimator_kwargs_if_accept_eula_present(self.init_kwargs, accept_eula)

return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict())

Expand Down
36 changes: 23 additions & 13 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
from sagemaker.jumpstart.utils import (
add_hub_content_arn_tags,
add_jumpstart_model_info_tags,
get_eula_message,
get_default_jumpstart_session_with_user_agent_suffix,
get_top_ranked_config_name,
update_dict_if_key_not_present,
Expand Down Expand Up @@ -265,6 +264,7 @@ def get_fit_kwargs(
tolerate_deprecated_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
config_name: Optional[str] = None,
hub_access_config: Optional[Dict] = None,
) -> JumpStartEstimatorFitKwargs:
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -301,10 +301,32 @@ def get_fit_kwargs(
estimator_fit_kwargs = _add_region_to_kwargs(estimator_fit_kwargs)
estimator_fit_kwargs = _add_training_job_name_to_kwargs(estimator_fit_kwargs)
estimator_fit_kwargs = _add_fit_extra_kwargs(estimator_fit_kwargs)
estimator_fit_kwargs = _add_hub_access_config_to_kwargs_inputs(
estimator_fit_kwargs, hub_access_config
)

return estimator_fit_kwargs


def _add_hub_access_config_to_kwargs_inputs(
    kwargs: JumpStartEstimatorFitKwargs, hub_access_config=None
):
    """Adds HubAccessConfig to kwargs inputs.

    Handles the three input shapes this function recognises:

    * a plain string is wrapped in a ``TrainingInput`` carrying the config;
    * a ``TrainingInput`` has the config added in place;
    * a dict is walked value by value, applying the two rules above to each
      entry. Values of any other type are left as they are.

    Args:
        kwargs (JumpStartEstimatorFitKwargs): Fit kwargs whose ``inputs``
            attribute is updated.
        hub_access_config (dict): The HubAccessConfig to attach (Default: None).

    Returns:
        JumpStartEstimatorFitKwargs: The same ``kwargs`` object, updated.
    """
    inputs = kwargs.inputs

    if isinstance(inputs, str):
        kwargs.inputs = TrainingInput(s3_data=inputs, hub_access_config=hub_access_config)
    elif isinstance(inputs, TrainingInput):
        inputs.add_hub_access_config(hub_access_config=hub_access_config)
    elif isinstance(inputs, dict):
        # Only values of existing keys are replaced, so the dict keeps its
        # size and iterating over items() while assigning is safe.
        for channel_name, channel_input in inputs.items():
            if isinstance(channel_input, str):
                inputs[channel_name] = TrainingInput(
                    s3_data=channel_input, hub_access_config=hub_access_config
                )
            # Check the channel value, not the enclosing dict: testing
            # kwargs.inputs here could never match, so TrainingInput values
            # were silently skipped.
            elif isinstance(channel_input, TrainingInput):
                channel_input.add_hub_access_config(hub_access_config=hub_access_config)

    return kwargs


def get_deploy_kwargs(
model_id: str,
model_version: Optional[str] = None,
Expand Down Expand Up @@ -668,18 +690,6 @@ def _add_env_to_kwargs(
value,
)

environment = getattr(kwargs, "environment", {}) or {}
if (
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
and str(environment.get("accept_eula", "")).lower() != "true"
):
model_specs = kwargs.specs
if model_specs.is_gated_model():
raise ValueError(
"Need to define ‘accept_eula'='true' within Environment. "
f"{get_eula_message(model_specs, kwargs.region)}"
)

return kwargs


Expand Down
13 changes: 13 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,19 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
instance_type=instance_type, property_name="artifact_key"
)

def get_instance_specific_training_artifact_key(self, instance_type: str) -> Optional[str]:
    """Returns instance specific training artifact key.

    The ``training_artifact_uri`` property is looked up first; only when
    that lookup yields a falsy value is ``training_artifact_key`` consulted.

    Returns None if a model, instance type tuple does not have specific
    training artifact key.
    """
    artifact_uri = self._get_instance_specific_property(
        instance_type=instance_type, property_name="training_artifact_uri"
    )
    if artifact_uri:
        return artifact_uri

    return self._get_instance_specific_property(
        instance_type=instance_type, property_name="training_artifact_key"
    )

def get_instance_specific_resource_requirements(self, instance_type: str) -> Optional[str]:
"""Returns instance specific resource requirements.

Expand Down
41 changes: 41 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,6 +1632,47 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
return neo_bucket


def remove_env_var_from_estimator_kwargs_if_accept_eula_present(
    init_kwargs: dict, accept_eula: Optional[bool]
):
    """Remove the gated-model training env var when ``accept_eula`` is given.

    When ``accept_eula`` is not None, the entry keyed by
    ``constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY`` is dropped
    from ``init_kwargs["environment"]`` in place. The call is a no-op when
    ``accept_eula`` is None, when ``init_kwargs`` has no (or an empty)
    ``"environment"`` entry, or when the variable is not present.

    Args:
        init_kwargs (dict): Dictionary of kwargs when Estimator is instantiated.
        accept_eula (Optional[bool]): Whether or not the EULA was accepted,
            optionally passed in to Estimator.fit().
    """
    if accept_eula is None:
        return

    # .get() tolerates init_kwargs that never had an "environment" entry.
    environment = init_kwargs.get("environment")
    if not environment:
        return

    # pop() with a default avoids a KeyError when the variable is absent,
    # e.g. when this runs more than once against the same init_kwargs.
    environment.pop(constants.SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY, None)


def get_hub_access_config(hub_content_arn: Optional[str]):
    """Build the hub access config for a model reference.

    Args:
        hub_content_arn (Optional[str]): Arn of the model reference hub content.

    Returns:
        ``{"HubContentArn": hub_content_arn}`` when an ARN is given,
        otherwise ``None``.
    """
    if hub_content_arn is None:
        return None
    return {"HubContentArn": hub_content_arn}


def get_model_access_config(accept_eula: Optional[bool]):
    """Build the model access config from the EULA acceptance flag.

    Args:
        accept_eula (Optional[bool]): Whether or not the EULA was accepted,
            optionally passed in to Estimator.fit().

    Returns:
        ``{"AcceptEula": accept_eula}`` when a value is given (including
        ``False``), otherwise ``None``.
    """
    if accept_eula is None:
        return None
    return {"AcceptEula": accept_eula}


def get_latest_version(versions: List[str]) -> Optional[str]:
"""Returns the latest version using sem-ver when possible."""
try:
Expand Down
Loading