Skip to content

Commit dff2ccd

Browse files
committed
Merge remote-tracking branch 'origin' into fix/sagemaker-session-region-not-being-used
2 parents 8576b32 + 064378d commit dff2ccd

File tree

8 files changed

+1388
-50
lines changed

8 files changed

+1388
-50
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart environment variables."""
1414
from __future__ import absolute_import
15-
from typing import Dict, Optional
15+
from typing import Callable, Dict, Optional, Set
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18+
JUMPSTART_LOGGER,
1819
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
1920
)
2021
from sagemaker.jumpstart.enums import (
@@ -111,7 +112,9 @@ def _retrieve_default_environment_variables(
111112

112113
default_environment_variables.update(instance_specific_environment_variables)
113114

114-
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
115+
retrieve_gated_env_var_for_instance_type: Callable[
116+
[str], Optional[str]
117+
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
115118
model_id=model_id,
116119
model_version=model_version,
117120
region=region,
@@ -121,6 +124,33 @@ def _retrieve_default_environment_variables(
121124
instance_type=instance_type,
122125
)
123126

127+
gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
128+
instance_type
129+
)
130+
131+
if gated_model_env_var is None and model_specs.is_gated_model():
132+
133+
possible_env_vars: Set[str] = {
134+
retrieve_gated_env_var_for_instance_type(instance_type)
135+
for instance_type in model_specs.supported_training_instance_types
136+
}
137+
138+
# If all officially supported instance types have the same underlying artifact,
139+
# we can use this artifact with high confidence that it'll succeed with
140+
# an arbitrary instance.
141+
if len(possible_env_vars) == 1:
142+
gated_model_env_var = list(possible_env_vars)[0]
143+
144+
# If this model does not have 1 artifact for all supported instance types,
145+
# we cannot determine which artifact to use for an arbitrary instance.
146+
else:
147+
log_msg = (
148+
f"'{model_id}' does not support {instance_type} instance type"
149+
" for training. Please use one of the following instance types: "
150+
f"{', '.join(model_specs.supported_training_instance_types)}."
151+
)
152+
JUMPSTART_LOGGER.warning(log_msg)
153+
124154
if gated_model_env_var is not None:
125155
default_environment_variables.update(
126156
{SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_model_env_var}

src/sagemaker/jumpstart/factory/estimator.py

+21
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
)
6363
from sagemaker.jumpstart.utils import (
6464
add_jumpstart_model_id_version_tags,
65+
get_eula_message,
6566
update_dict_if_key_not_present,
6667
resolve_estimator_sagemaker_config_field,
6768
verify_model_region_and_return_specs,
@@ -601,6 +602,26 @@ def _add_env_to_kwargs(
601602
value,
602603
)
603604

605+
environment = getattr(kwargs, "environment", {}) or {}
606+
if (
607+
environment.get(SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY)
608+
and str(environment.get("accept_eula", "")).lower() != "true"
609+
):
610+
model_specs = verify_model_region_and_return_specs(
611+
model_id=kwargs.model_id,
612+
version=kwargs.model_version,
613+
region=kwargs.region,
614+
scope=JumpStartScriptScope.TRAINING,
615+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
616+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
617+
sagemaker_session=kwargs.sagemaker_session,
618+
)
619+
if model_specs.is_gated_model():
620+
raise ValueError(
621+
"Need to define ‘accept_eula'='true' within Environment. "
622+
f"{get_eula_message(model_specs, kwargs.region)}"
623+
)
624+
604625
return kwargs
605626

606627

src/sagemaker/jumpstart/types.py

+4
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,10 @@ def use_training_model_artifact(self) -> bool:
963963
# otherwise, return true if a training model package is not set
964964
return len(self.training_model_package_artifact_uris or {}) == 0
965965

966+
def is_gated_model(self) -> bool:
    """Returns True if the model has a EULA key or the model bucket is gated.

    The result is always a real ``bool``: ``gated_bucket`` is coerced so that a
    truthy non-boolean value is never handed back to the caller as-is.

    Returns:
        bool: True when ``gated_bucket`` is truthy or ``hosting_eula_key`` is
        not None, otherwise False.
    """
    # Coerce explicitly: ``x or y`` would return ``x`` itself when it is truthy.
    return bool(self.gated_bucket) or self.hosting_eula_key is not None
969+
966970
def supports_incremental_training(self) -> bool:
967971
"""Returns True if the model supports incremental training."""
968972
return self.incremental_training_supported

src/sagemaker/jumpstart/utils.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -476,21 +476,25 @@ def update_inference_tags_with_jumpstart_training_tags(
476476
return inference_tags
477477

478478

479+
def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
    """Builds the EULA notice for a model, or an empty string if it has no EULA key.

    Args:
        model_specs: Specs of the model; ``model_id`` and ``hosting_eula_key`` are read.
        region: Region used to look up the content bucket and to build the S3 URL.

    Returns:
        str: A message pointing at the EULA document, or ``""`` when
        ``hosting_eula_key`` is None.
    """
    eula_key = model_specs.hosting_eula_key
    if eula_key is None:
        return ""

    content_bucket = get_jumpstart_content_bucket(region=region)
    # Regions starting with "cn-" get a ".cn" suffix on the amazonaws.com domain.
    domain_suffix = ".cn" if region.startswith("cn-") else ""
    eula_url = f"https://{content_bucket}.s3.{region}.amazonaws.com{domain_suffix}/{eula_key}"

    return (
        f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
        f"See {eula_url} for terms of use."
    )
489+
490+
479491
def emit_logs_based_on_model_specs(
480492
model_specs: JumpStartModelSpecs, region: str, s3_client: boto3.client
481493
) -> None:
482494
"""Emits logs based on model specs and region."""
483495

484496
if model_specs.hosting_eula_key:
485-
constants.JUMPSTART_LOGGER.info(
486-
"Model '%s' requires accepting end-user license agreement (EULA). "
487-
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
488-
model_specs.model_id,
489-
get_jumpstart_content_bucket(region=region),
490-
region,
491-
".cn" if region.startswith("cn-") else "",
492-
model_specs.hosting_eula_key,
493-
)
497+
constants.JUMPSTART_LOGGER.info(get_eula_message(model_specs, region))
494498

495499
full_version: str = model_specs.version
496500

tests/unit/sagemaker/environment_variables/jumpstart/test_default.py

+65
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import environment_variables
21+
from sagemaker.jumpstart.utils import get_jumpstart_gated_content_bucket
2122
from sagemaker.jumpstart.enums import JumpStartModelType
2223

2324
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
@@ -204,6 +205,70 @@ def test_jumpstart_sdk_environment_variables(
204205
)
205206

206207

208+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_1_artifact_all_variants(patched_get_model_specs):
    """Training env vars for 'gemma-model-1-artifact' on ml.p3.2xlarge carry the gated URI."""
    patched_get_model_specs.side_effect = get_special_model_spec

    aws_region = "us-west-2"

    expected_env_vars = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(aws_region)}/"
        "huggingface-training/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }

    actual_env_vars = environment_variables.retrieve_default(
        region=aws_region,
        model_id="gemma-model-1-artifact",
        model_version="*",
        include_aws_sdk_env_vars=False,
        sagemaker_session=mock_session,
        instance_type="ml.p3.2xlarge",
        script="training",
    )

    assert expected_env_vars == actual_env_vars
228+
229+
230+
@patch("sagemaker.jumpstart.artifacts.environment_variables.JUMPSTART_LOGGER")
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_environment_variables_no_gated_env_var_available(
    patched_get_model_specs, patched_jumpstart_logger
):
    """'gemma-model' on ml.p3.2xlarge yields no env vars and one warning; ml.g5.24xlarge works."""
    patched_get_model_specs.side_effect = get_special_model_spec

    aws_region = "us-west-2"

    def training_env_vars_for(instance_type):
        # Same retrieve_default call for every case; only the instance type varies.
        return environment_variables.retrieve_default(
            region=aws_region,
            model_id="gemma-model",
            model_version="*",
            include_aws_sdk_env_vars=False,
            sagemaker_session=mock_session,
            instance_type=instance_type,
            script="training",
        )

    # ml.p3.2xlarge: nothing is returned and exactly one warning is logged.
    assert training_env_vars_for("ml.p3.2xlarge") == {}

    patched_jumpstart_logger.warning.assert_called_once_with(
        "'gemma-model' does not support ml.p3.2xlarge instance type for "
        "training. Please use one of the following instance types: "
        "ml.g5.12xlarge, ml.g5.24xlarge, ml.g5.48xlarge, ml.p4d.24xlarge."
    )

    # assert that supported instance types succeed
    expected_env_vars = {
        "SageMakerGatedModelS3Uri": f"s3://{get_jumpstart_gated_content_bucket(aws_region)}/"
        "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-7b-instruct.tar.gz"
    }
    assert expected_env_vars == training_env_vars_for("ml.g5.24xlarge")
270+
271+
207272
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
208273
def test_jumpstart_sdk_environment_variables_instance_type_overrides(patched_get_model_specs):
209274

0 commit comments

Comments
 (0)