Skip to content

chore: emit warning when no instance specific gated training env var is available, and raise exception when accept_eula flag is not supplied #4485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 13, 2024
7 changes: 7 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def _retrieve_default_environment_variables(
instance_type=instance_type,
)

if gated_model_env_var is None and model_specs.gated_bucket:
raise ValueError(
f"'{model_id}' does not support {instance_type} instance type for training. "
"Please use one of the following instance types: "
f"{', '.join(model_specs.supported_training_instance_types)}."
)

if gated_model_env_var is not None:
default_environment_variables.update(
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

from sagemaker import environment_variables
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket

from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec

Expand Down Expand Up @@ -177,6 +178,46 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
)


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(patched_get_model_specs):
    """Check training env var retrieval for "gemma-model" on two instance types.

    Model specs are served by ``get_special_model_spec`` through the patched
    accessor. Requesting training environment variables for ``ml.p3.2xlarge``
    must raise a ``ValueError`` whose message lists the supported training
    instance types, while ``ml.g5.24xlarge`` must return the gated model S3 URI
    environment variable.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    aws_region = "us-west-2"

    # Everything except the instance type is identical across both lookups.
    shared_kwargs = dict(
        region=aws_region,
        model_id="gemma-model",
        model_version="*",
        include_aws_sdk_env_vars=False,
        sagemaker_session=mock_session,
        script="training",
    )

    # Unsupported instance type: retrieval must fail with a descriptive error.
    with pytest.raises(ValueError) as raised:
        environment_variables.retrieve_default(instance_type="ml.p3.2xlarge", **shared_kwargs)

    expected_message = (
        "'gemma-model' does not support ml.p3.2xlarge instance type for "
        "training. Please use one of the following instance types: "
        "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
    )
    assert str(raised.value) == expected_message

    # Supported instance type: the gated artifact URI is returned as an env var.
    expected_env = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(aws_region)}/"
        "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }
    resolved_env = environment_variables.retrieve_default(
        instance_type="ml.g5.24xlarge", **shared_kwargs
    )
    assert expected_env == resolved_env


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):

Expand Down
Loading