|
18 | 18 | import pytest
|
19 | 19 |
|
20 | 20 | from sagemaker import environment_variables
|
| 21 | +from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket |
21 | 22 | from sagemaker.jumpstart.enums import JumpStartModelType
|
22 | 23 |
|
23 | 24 | from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
|
@@ -204,6 +205,70 @@ def test_jumpstart_sdk_environment_variables(
|
204 | 205 | )
|
205 | 206 |
|
206 | 207 |
|
| 208 | +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 209 | +def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs): |
| 210 | + |
| 211 | + patched_get_model_specs.side_effect = get_special_model_spec |
| 212 | + |
| 213 | + model_id = "gemma-model-1-artifact" |
| 214 | + region = "us-west-2" |
| 215 | + |
| 216 | + assert { |
| 217 | + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" |
| 218 | + "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz" |
| 219 | + } == environment_variables.retrieve_default( |
| 220 | + region=region, |
| 221 | + model_id=model_id, |
| 222 | + model_version="*", |
| 223 | + include_aws_sdk_env_vars=False, |
| 224 | + sagemaker_session=mock_session, |
| 225 | + instance_type="ml.p3.2xlarge", |
| 226 | + script="training", |
| 227 | + ) |
| 228 | + |
| 229 | + |
| 230 | +@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER") |
| 231 | +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") |
| 232 | +def test_jumpstart_sdk_environment_variables_no_gated_env_var_available( |
| 233 | + patched_get_model_specs, patched_jumpstart_logger |
| 234 | +): |
| 235 | + |
| 236 | + patched_get_model_specs.side_effect = get_special_model_spec |
| 237 | + |
| 238 | + model_id = "gemma-model" |
| 239 | + region = "us-west-2" |
| 240 | + |
| 241 | + assert {} == environment_variables.retrieve_default( |
| 242 | + region=region, |
| 243 | + model_id=model_id, |
| 244 | + model_version="*", |
| 245 | + include_aws_sdk_env_vars=False, |
| 246 | + sagemaker_session=mock_session, |
| 247 | + instance_type="ml.p3.2xlarge", |
| 248 | + script="training", |
| 249 | + ) |
| 250 | + |
| 251 | + patched_jumpstart_logger.warning.assert_called_once_with( |
| 252 | + "'gemma-model' does not support ml.p3.2xlarge instance type for " |
| 253 | + "training. Please use one of the following instance types: " |
| 254 | + "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge." |
| 255 | + ) |
| 256 | + |
| 257 | + # assert that supported instance types succeed |
| 258 | + assert { |
| 259 | + "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/" |
| 260 | + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz" |
| 261 | + } == environment_variables.retrieve_default( |
| 262 | + region=region, |
| 263 | + model_id=model_id, |
| 264 | + model_version="*", |
| 265 | + include_aws_sdk_env_vars=False, |
| 266 | + sagemaker_session=mock_session, |
| 267 | + instance_type="ml.g5.24xlarge", |
| 268 | + script="training", |
| 269 | + ) |
| 270 | + |
| 271 | + |
207 | 272 | @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
|
208 | 273 | def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
|
209 | 274 |
|
|
0 commit comments