Skip to content

chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied #4485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 13, 2024
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def _retrieve_default_environment_variables(
instance_type=instance_type,
)

if gated_model_env_var is None and model_specs.is_gated_model():
raise ValueError(
f"'{model_id}' does not support {instance_type} instance type for training. "
"Please use one of the following instance types: "
f"{', '.join(model_specs.supported_training_instance_types)}."
)

if gated_model_env_var is not None:
default_environment_variables.update(
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}
Expand Down
21 changes: 21 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from sagemaker.jumpstart.utils import (
add_jumpstart_model_id_version_tags,
get_eula_message,
update_dict_if_key_not_present,
resolve_estimator_sagemaker_config_field,
verify_model_region_and_return_specs,
Expand Down Expand Up @@ -595,6 +596,26 @@ def _add_env_to_kwargs(
value,
)

environment = getattr(kwargs, "environment", {}) or {}
if (
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
and str(environment.get("accept_eula", "")).lower() != "true"
):
model_specs = verify_model_region_and_return_specs(
model_id=kwargs.model_id,
version=kwargs.model_version,
region=kwargs.region,
scope=JumpStartScriptScope.TRAINING,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
)
if model_specs.is_gated_model():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: shouldn't this be the outer if statement here? And we would throw an error in all of the cases (the environment does not contain the special key, accept_eula is missing, accept_eula is false, etc.)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather make the conditions as specific as possible. Maybe we'll change how we launch gated models in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, this is odd, why bother retrieving the specs if you're not going to use them?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fwiw, this won't involve another s3 call, it's just reading from memory

raise ValueError(
"Need to define ‘accept_eula'='true' within Environment. "
f"{get_eula_message(model_specs, kwargs.region)}"
)

return kwargs


Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,10 @@ def use_training_model_artifact(self) -> bool:
# otherwise, return true if a training model package is not set
return len(self.training_model_package_artifact_uris or {}) == 0

def is_gated_model(self) -> bool:
    """Returns True if the model has a EULA key or the model bucket is gated.

    A model is treated as gated when either marker is present: a truthy
    ``gated_bucket`` attribute, or a ``hosting_eula_key`` that is not None.

    Returns:
        bool: True when at least one of the two gated-model markers is set.
    """
    # Coerce ``gated_bucket`` so the method always honours its ``-> bool``
    # annotation instead of leaking a truthy non-bool value to callers.
    return bool(self.gated_bucket) or self.hosting_eula_key is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this is odd: shouldn't the pySDK only use gated_bucket as the indicator?

Isn't it MH's unit test job to ensure that if hosting_eula_key is passed, then gated_bucket is set correctly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have a unit test for this (though we should). It was confusing (at least for me) to refer to gated buckets and hosting EULA keys separately, so to clear up the confusion I created this helper function to clarify that these are the markers for gated models (the bucket is gated or there's a EULA key).


def supports_incremental_training(self) -> bool:
    """Reports whether incremental training is supported by this model.

    Returns:
        bool: The value of the ``incremental_training_supported`` attribute.
    """
    incremental_flag = self.incremental_training_supported
    return incremental_flag
Expand Down
22 changes: 13 additions & 9 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,21 +466,25 @@ def update_inference_tags_with_jumpstart_training_tags(
return inference_tags


def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
"""Returns EULA message to display to customers if one is available, else empty string."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove 'to customers' from docstring.

if model_specs.hosting_eula_key is None:
return ""
return (
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"amazonaws.com{'.cn' if region.startswith('cn-') else ''}"
f"/{model_specs.hosting_eula_key} for terms of use."
)


def emit_logs_based_on_model_specs(
model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
) -> None:
"""Emits logs based on model specs and region."""

if model_specs.hosting_eula_key:
constants.JUMPSTART_LOGGER.info(
"Model '%s' requires accepting end-user license agreement (EULA). "
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
model_specs.model_id,
get_jumpstart_content_bucket(region=region),
region,
".cn" if region.startswith("cn-") else "",
model_specs.hosting_eula_key,
)
constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))

full_version: str = model_specs.version

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

from sagemaker import environment_variables
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket

from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec

Expand Down Expand Up @@ -177,6 +178,46 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(patched_get_model_specs):
    """Checks default training env vars for ``gemma-model`` across instance types.

    Requesting ``ml.p3.2xlarge`` must raise a ``ValueError`` whose message lists
    the supported training instance types, while ``ml.g5.24xlarge`` must return
    the ``SageMakerGatedModelS3Uri`` entry pointing into the gated content bucket.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    region = "us-west-2"
    model_id = "gemma-model"

    def retrieve_for(instance_type):
        # Both scenarios differ only in the instance type being requested.
        return environment_variables.retrieve_default(
            region=region,
            model_id=model_id,
            model_version="*",
            include_aws_sdk_env_vars=False,
            sagemaker_session=mock_session,
            instance_type=instance_type,
            script="training",
        )

    # assert that unsupported instance types raise an exception
    expected_message = (
        "'gemma-model' does not support ml.p3.2xlarge instance type for "
        "training. Please use one of the following instance types: "
        "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
    )
    with pytest.raises(ValueError) as raised:
        retrieve_for("ml.p3.2xlarge")
    assert str(raised.value) == expected_message

    # assert that supported instance types succeed
    expected_env_vars = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
        "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }
    assert expected_env_vars == retrieve_for("ml.g5.24xlarge")


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):

Expand Down
Loading