|
18 | 18 | import pytest
|
19 | 19 |
|
20 | 20 | from sagemaker import environment_variables
|
| 21 | +from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket |
21 | 22 |
|
22 | 23 | from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
|
23 | 24 |
|
@@ -177,6 +178,46 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
|
177 | 178 | )
|
178 | 179 |
|
179 | 180 |
|
| 181 | +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 182 | +def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(patched_get_model_specs): |
| 183 | + |
| 184 | + patched_get_model_specs.side_effect = get_special_model_spec |
| 185 | + |
| 186 | + model_id = "gemma-model" |
| 187 | + region = "us-west-2" |
| 188 | + |
| 189 | + # assert that unsupported instance types raise an exception |
| 190 | + with pytest.raises(ValueError) as e: |
| 191 | + environment_variables.retrieve_default( |
| 192 | + region=region, |
| 193 | + model_id=model_id, |
| 194 | + model_version="*", |
| 195 | + include_aws_sdk_env_vars=False, |
| 196 | + sagemaker_session=mock_session, |
| 197 | + instance_type="ml.p3.2xlarge", |
| 198 | + script="training", |
| 199 | + ) |
| 200 | + assert ( |
| 201 | + str(e.value) == "'gemma-model' does not support ml.p3.2xlarge instance type for " |
| 202 | + "training. Please use one of the following instance types: " |
| 203 | + "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge." |
| 204 | + ) |
| 205 | + |
| 206 | + # assert that supported instance types succeed |
| 207 | + assert { |
| 208 | + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" |
| 209 | + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" |
| 210 | + } == environment_variables.retrieve_default( |
| 211 | + region=region, |
| 212 | + model_id=model_id, |
| 213 | + model_version="*", |
| 214 | + include_aws_sdk_env_vars=False, |
| 215 | + sagemaker_session=mock_session, |
| 216 | + instance_type="ml.g5.24xlarge", |
| 217 | + script="training", |
| 218 | + ) |
| 219 | + |
| 220 | + |
180 | 221 | @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
|
181 | 222 | def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
|
182 | 223 |
|
|
0 commit comments