Skip to content

Commit c554a9f

Browse files
committed
fix: raise exception when no instance specific gated training env var available
1 parent fcbd0bf commit c554a9f

File tree

3 files changed

+660
-0
lines changed

3 files changed

+660
-0
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

+7
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def _retrieve_default_environment_variables(
120120
instance_type=instance_type,
121121
)
122122

123+
if gated_model_env_var is None and model_specs.gated_bucket:
124+
raise ValueError(
125+
f"'{model_id}' does not support {instance_type} instance type for training. "
126+
"Please use one of the following instance types: "
127+
f"{', '.join(model_specs.supported_training_instance_types)}."
128+
)
129+
123130
if gated_model_env_var is not None:
124131
default_environment_variables.update(
125132
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

+41
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import environment_variables
21+
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket
2122

2223
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2324

@@ -177,6 +178,46 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
177178
)
178179

179180

181+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
182+
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(patched_get_model_specs):
183+
184+
patched_get_model_specs.side_effect = get_special_model_spec
185+
186+
model_id = "gemma-model"
187+
region = "us-west-2"
188+
189+
# assert that unsupported instance types raise an exception
190+
with pytest.raises(ValueError) as e:
191+
environment_variables.retrieve_default(
192+
region=region,
193+
model_id=model_id,
194+
model_version="*",
195+
include_aws_sdk_env_vars=False,
196+
sagemaker_session=mock_session,
197+
instance_type="ml.p3.2xlarge",
198+
script="training",
199+
)
200+
assert (
201+
str(e.value) == "'gemma-model' does not support ml.p3.2xlarge instance type for "
202+
"training. Please use one of the following instance types: "
203+
"ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
204+
)
205+
206+
# assert that supported instance types succeed
207+
assert {
208+
"SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
209+
"huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
210+
} == environment_variables.retrieve_default(
211+
region=region,
212+
model_id=model_id,
213+
model_version="*",
214+
include_aws_sdk_env_vars=False,
215+
sagemaker_session=mock_session,
216+
instance_type="ml.g5.24xlarge",
217+
script="training",
218+
)
219+
220+
180221
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
181222
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
182223

0 commit comments

Comments
 (0)