Skip to content

Commit e239779

Browse files
committed
chore: emit warning when instance type is chosen with no gated training artifacts
1 parent 03b2c9c commit e239779

File tree

4 files changed

+687
-22
lines changed

4 files changed

+687
-22
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart environment variables."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Optional
15+
from typing import Callable, Dict, Optional, Set
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
1818
JUMPSTART_DEFAULT_REGION_NAME,
19+
JUMPSTART_LOGGER,
1920
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2021
)
2122
from sagemaker.jumpstart.enums import (
@@ -110,7 +111,9 @@ def _retrieve_default_environment_variables(
110111

111112
default_environment_variables.update(instance_specific_environment_variables)
112113

113-
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
114+
retrieve_gated_env_var_for_instance_type: Callable[
115+
[str], Optional[str]
116+
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
114117
model_id=model_id,
115118
model_version=model_version,
116119
region=region,
@@ -120,12 +123,32 @@ def _retrieve_default_environment_variables(
120123
instance_type=instance_type,
121124
)
122125

126+
gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
127+
instance_type
128+
)
129+
123130
if gated_model_env_var is None and model_specs.is_gated_model():
124-
raise ValueError(
125-
f"'{model_id}' does not support {instance_type} instance type for training. "
126-
"Please use one of the following instance types: "
127-
f"{', '.join(model_specs.supported_training_instance_types)}."
128-
)
131+
132+
possible_env_vars: Set[str] = {
133+
retrieve_gated_env_var_for_instance_type(instance_type)
134+
for instance_type in model_specs.supported_training_instance_types
135+
}
136+
137+
# If all officially supported instance types have the same underlying artifact,
138+
# we can use this artifact with high confidence that it'll succeed with
139+
# an arbitrary instance.
140+
if len(possible_env_vars) == 1:
141+
gated_model_env_var = list(possible_env_vars)[0]
142+
143+
# If this model does not have 1 artifact for all supported instance types,
144+
# we cannot determine which artifact to use for an arbitrary instance.
145+
else:
146+
log_msg = (
147+
f"'{model_id}' does not support {instance_type} instance type"
148+
" for training. Please use one of the following instance types: "
149+
f"{', '.join(model_specs.supported_training_instance_types)}."
150+
)
151+
JUMPSTART_LOGGER.warning(log_msg)
129152

130153
if gated_model_env_var is not None:
131154
default_environment_variables.update(

src/sagemaker/jumpstart/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def update_inference_tags_with_jumpstart_training_tags(
467467

468468

469469
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
470-
"""Returns EULA message to display to customers if one is available, else empty string."""
470+
"""Returns EULA message to display if one is available, else empty string."""
471471
if model_specs.hosting_eula_key is None:
472472
return ""
473473
return (

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

+38-14
Original file line numberDiff line numberDiff line change
@@ -179,26 +179,50 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs):
179179

180180

181181
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
182-
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(patched_get_model_specs):
182+
def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs):
183+
184+
patched_get_model_specs.side_effect = get_special_model_spec
185+
186+
model_id = "gemma-model-1-artifact"
187+
region = "us-west-2"
188+
189+
assert {
190+
"SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(region)}/"
191+
"huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz"
192+
} == environment_variables.retrieve_default(
193+
region=region,
194+
model_id=model_id,
195+
model_version="*",
196+
include_aws_sdk_env_vars=False,
197+
sagemaker_session=mock_session,
198+
instance_type="ml.p3.2xlarge",
199+
script="training",
200+
)
201+
202+
203+
@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER")
204+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
205+
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(
206+
patched_get_model_specs, patched_jumpstart_logger
207+
):
183208

184209
patched_get_model_specs.side_effect = get_special_model_spec
185210

186211
model_id = "gemma-model"
187212
region = "us-west-2"
188213

189-
# assert that unsupported instance types raise an exception
190-
with pytest.raises(ValueError) as e:
191-
environment_variables.retrieve_default(
192-
region=region,
193-
model_id=model_id,
194-
model_version="*",
195-
include_aws_sdk_env_vars=False,
196-
sagemaker_session=mock_session,
197-
instance_type="ml.p3.2xlarge",
198-
script="training",
199-
)
200-
assert (
201-
str(e.value) == "'gemma-model' does not support ml.p3.2xlarge instance type for "
214+
assert {} == environment_variables.retrieve_default(
215+
region=region,
216+
model_id=model_id,
217+
model_version="*",
218+
include_aws_sdk_env_vars=False,
219+
sagemaker_session=mock_session,
220+
instance_type="ml.p3.2xlarge",
221+
script="training",
222+
)
223+
224+
patched_jumpstart_logger.warning.assert_called_once_with(
225+
"'gemma-model' does not support ml.p3.2xlarge instance type for "
202226
"training. Please use one of the following instance types: "
203227
"ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
204228
)

0 commit comments

Comments
 (0)