@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
39
39
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
40
40
instance_type : Optional [str ] = None ,
41
41
script : JumpStartScriptScope = JumpStartScriptScope .INFERENCE ,
42
+ config_name : Optional [str ] = None ,
42
43
) -> Dict [str , str ]:
43
44
"""Retrieves the inference environment variables for the model matching the given arguments.
44
45
@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
68
69
environment variables specific for the instance type.
69
70
script (JumpStartScriptScope): The JumpStart script for which to retrieve
70
71
environment variables.
72
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
71
73
Returns:
72
74
dict: the inference environment variables to use for the model.
73
75
"""
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
84
86
tolerate_vulnerable_model = tolerate_vulnerable_model ,
85
87
tolerate_deprecated_model = tolerate_deprecated_model ,
86
88
sagemaker_session = sagemaker_session ,
89
+ config_name = config_name ,
87
90
)
88
91
89
92
default_environment_variables : Dict [str , str ] = {}
@@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
121
124
tolerate_deprecated_model = tolerate_deprecated_model ,
122
125
sagemaker_session = sagemaker_session ,
123
126
instance_type = instance_type ,
127
+ config_name = config_name ,
124
128
)
125
129
)
126
130
@@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
167
171
tolerate_deprecated_model : bool = False ,
168
172
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
169
173
instance_type : Optional [str ] = None ,
174
+ config_name : Optional [str ] = None ,
170
175
) -> Optional [str ]:
171
176
"""Retrieves the gated model env var URI matching the given arguments.
172
177
@@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
190
195
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
191
196
instance_type (str): An instance type to optionally supply in order to get
192
197
environment variables specific for the instance type.
198
+ config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
193
199
194
200
Returns:
195
201
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
211
217
tolerate_vulnerable_model = tolerate_vulnerable_model ,
212
218
tolerate_deprecated_model = tolerate_deprecated_model ,
213
219
sagemaker_session = sagemaker_session ,
220
+ config_name = config_name ,
214
221
)
215
222
216
223
s3_key : Optional [str ] = (
0 commit comments