@@ -40,6 +40,7 @@ def _retrieve_default_environment_variables(
40
40
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
41
41
instance_type : Optional [str ] = None ,
42
42
script : JumpStartScriptScope = JumpStartScriptScope .INFERENCE ,
43
+ config_name : Optional [str ] = None ,
43
44
) -> Dict [str , str ]:
44
45
"""Retrieves the inference environment variables for the model matching the given arguments.
45
46
@@ -71,6 +72,7 @@ def _retrieve_default_environment_variables(
71
72
environment variables specific for the instance type.
72
73
script (JumpStartScriptScope): The JumpStart script for which to retrieve
73
74
environment variables.
75
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
74
76
Returns:
75
77
dict: the inference environment variables to use for the model.
76
78
"""
@@ -88,6 +90,7 @@ def _retrieve_default_environment_variables(
88
90
tolerate_vulnerable_model = tolerate_vulnerable_model ,
89
91
tolerate_deprecated_model = tolerate_deprecated_model ,
90
92
sagemaker_session = sagemaker_session ,
93
+ config_name = config_name ,
91
94
)
92
95
93
96
default_environment_variables : Dict [str , str ] = {}
@@ -126,6 +129,7 @@ def _retrieve_default_environment_variables(
126
129
tolerate_deprecated_model = tolerate_deprecated_model ,
127
130
sagemaker_session = sagemaker_session ,
128
131
instance_type = instance_type ,
132
+ config_name = config_name ,
129
133
)
130
134
)
131
135
@@ -173,6 +177,7 @@ def _retrieve_gated_model_uri_env_var_value(
173
177
tolerate_deprecated_model : bool = False ,
174
178
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
175
179
instance_type : Optional [str ] = None ,
180
+ config_name : Optional [str ] = None ,
176
181
) -> Optional [str ]:
177
182
"""Retrieves the gated model env var URI matching the given arguments.
178
183
@@ -198,6 +203,7 @@ def _retrieve_gated_model_uri_env_var_value(
198
203
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
199
204
instance_type (str): An instance type to optionally supply in order to get
200
205
environment variables specific for the instance type.
206
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
201
207
202
208
Returns:
203
209
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -220,6 +226,7 @@ def _retrieve_gated_model_uri_env_var_value(
220
226
tolerate_vulnerable_model = tolerate_vulnerable_model ,
221
227
tolerate_deprecated_model = tolerate_deprecated_model ,
222
228
sagemaker_session = sagemaker_session ,
229
+ config_name = config_name ,
223
230
)
224
231
225
232
s3_key : Optional [str ] = (
0 commit comments