12
12
# language governing permissions and limitations under the License.
13
13
"""This module contains functions for obtaining JumpStart environment variables."""
14
14
from __future__ import absolute_import
15
- from typing import Dict , Optional
15
+ from typing import Callable , Dict , Optional , Set
16
16
from sagemaker .jumpstart .constants import (
17
17
DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
18
18
JUMPSTART_DEFAULT_REGION_NAME ,
19
+ JUMPSTART_LOGGER ,
19
20
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY ,
20
21
)
21
22
from sagemaker .jumpstart .enums import (
@@ -110,7 +111,9 @@ def _retrieve_default_environment_variables(
110
111
111
112
default_environment_variables .update (instance_specific_environment_variables )
112
113
113
- gated_model_env_var : Optional [str ] = _retrieve_gated_model_uri_env_var_value (
114
+ retrieve_gated_env_var_for_instance_type : Callable [
115
+ [str ], Optional [str ]
116
+ ] = lambda instance_type : _retrieve_gated_model_uri_env_var_value (
114
117
model_id = model_id ,
115
118
model_version = model_version ,
116
119
region = region ,
@@ -120,12 +123,32 @@ def _retrieve_default_environment_variables(
120
123
instance_type = instance_type ,
121
124
)
122
125
126
+ gated_model_env_var : Optional [str ] = retrieve_gated_env_var_for_instance_type (
127
+ instance_type
128
+ )
129
+
123
130
if gated_model_env_var is None and model_specs .is_gated_model ():
124
- raise ValueError (
125
- f"'{ model_id } ' does not support { instance_type } instance type for training. "
126
- "Please use one of the following instance types: "
127
- f"{ ', ' .join (model_specs .supported_training_instance_types )} ."
128
- )
131
+
132
+ possible_env_vars : Set [str ] = {
133
+ retrieve_gated_env_var_for_instance_type (instance_type )
134
+ for instance_type in model_specs .supported_training_instance_types
135
+ }
136
+
137
+ # If all officially supported instance types have the same underlying artifact,
138
+ # we can use this artifact with high confidence that it'll succeed with
139
+ # an arbitrary instance.
140
+ if len (possible_env_vars ) == 1 :
141
+ gated_model_env_var = list (possible_env_vars )[0 ]
142
+
143
+ # If this model does not have 1 artifact for all supported instance types,
144
+ # we cannot determine which artifact to use for an arbitrary instance.
145
+ else :
146
+ log_msg = (
147
+ f"'{ model_id } ' does not support { instance_type } instance type"
148
+ " for training. Please use one of the following instance types: "
149
+ f"{ ', ' .join (model_specs .supported_training_instance_types )} ."
150
+ )
151
+ JUMPSTART_LOGGER .warning (log_msg )
129
152
130
153
if gated_model_env_var is not None :
131
154
default_environment_variables .update (
0 commit comments