Skip to content

Commit 8eee1d8

Browse files
authored
feat: support config_name in all JumpStart interfaces (#4583) (#4607)
* add-config-name * address comments * updates for set config * docstyle * updates * fix * format * format * remove tests
1 parent f026f02 commit 8eee1d8

40 files changed

+691
-92
lines changed

src/sagemaker/accept_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def retrieve_default(
7777
tolerate_deprecated_model: bool = False,
7878
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7979
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
80+
config_name: Optional[str] = None,
8081
) -> str:
8182
"""Retrieves the default accept type for the model matching the given arguments.
8283
@@ -98,6 +99,7 @@ def retrieve_default(
9899
object, used for SageMaker interactions. If not
99100
specified, one is created using the default AWS configuration
100101
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
102+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
101103
Returns:
102104
str: The default accept type to use for the model.
103105
@@ -117,4 +119,5 @@ def retrieve_default(
117119
tolerate_deprecated_model,
118120
sagemaker_session=sagemaker_session,
119121
model_type=model_type,
122+
config_name=config_name,
120123
)

src/sagemaker/content_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def retrieve_default(
7777
tolerate_deprecated_model: bool = False,
7878
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
7979
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
80+
config_name: Optional[str] = None,
8081
) -> str:
8182
"""Retrieves the default content type for the model matching the given arguments.
8283
@@ -98,6 +99,7 @@ def retrieve_default(
9899
object, used for SageMaker interactions. If not
99100
specified, one is created using the default AWS configuration
100101
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
102+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
101103
Returns:
102104
str: The default content type to use for the model.
103105
@@ -117,6 +119,7 @@ def retrieve_default(
117119
tolerate_deprecated_model,
118120
sagemaker_session=sagemaker_session,
119121
model_type=model_type,
122+
config_name=config_name,
120123
)
121124

122125

src/sagemaker/deserializers.py

+3
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def retrieve_default(
9797
tolerate_deprecated_model: bool = False,
9898
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
9999
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
100+
config_name: Optional[str] = None,
100101
) -> BaseDeserializer:
101102
"""Retrieves the default deserializer for the model matching the given arguments.
102103
@@ -118,6 +119,7 @@ def retrieve_default(
118119
object, used for SageMaker interactions. If not
119120
specified, one is created using the default AWS configuration
120121
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
122+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
121123
Returns:
122124
BaseDeserializer: The default deserializer to use for the model.
123125
@@ -138,4 +140,5 @@ def retrieve_default(
138140
tolerate_deprecated_model,
139141
sagemaker_session=sagemaker_session,
140142
model_type=model_type,
143+
config_name=config_name,
141144
)

src/sagemaker/environment_variables.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def retrieve_default(
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3737
instance_type: Optional[str] = None,
3838
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
39+
config_name: Optional[str] = None,
3940
) -> Dict[str, str]:
4041
"""Retrieves the default container environment variables for the model matching the arguments.
4142
@@ -65,6 +66,7 @@ def retrieve_default(
6566
variables specific for the instance type.
6667
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
6768
variables.
69+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
6870
Returns:
6971
dict: The variables to use for the model.
7072
@@ -87,4 +89,5 @@ def retrieve_default(
8789
sagemaker_session=sagemaker_session,
8890
instance_type=instance_type,
8991
script=script,
92+
config_name=config_name,
9093
)

src/sagemaker/hyperparameters.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def retrieve_default(
3636
tolerate_vulnerable_model: bool = False,
3737
tolerate_deprecated_model: bool = False,
3838
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39+
config_name: Optional[str] = None,
3940
) -> Dict[str, str]:
4041
"""Retrieves the default training hyperparameters for the model matching the given arguments.
4142
@@ -66,6 +67,7 @@ def retrieve_default(
6667
object, used for SageMaker interactions. If not
6768
specified, one is created using the default AWS configuration
6869
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
70+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
6971
Returns:
7072
dict: The hyperparameters to use for the model.
7173
@@ -86,6 +88,7 @@ def retrieve_default(
8688
tolerate_vulnerable_model=tolerate_vulnerable_model,
8789
tolerate_deprecated_model=tolerate_deprecated_model,
8890
sagemaker_session=sagemaker_session,
91+
config_name=config_name,
8992
)
9093

9194

src/sagemaker/image_uris.py

+3
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def retrieve(
6868
inference_tool=None,
6969
serverless_inference_config=None,
7070
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
71+
config_name=None,
7172
) -> str:
7273
"""Retrieves the ECR URI for the Docker image matching the given arguments.
7374
@@ -121,6 +122,7 @@ def retrieve(
121122
object, used for SageMaker interactions. If not
122123
specified, one is created using the default AWS configuration
123124
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
125+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
124126
125127
Returns:
126128
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -160,6 +162,7 @@ def retrieve(
160162
tolerate_vulnerable_model,
161163
tolerate_deprecated_model,
162164
sagemaker_session=sagemaker_session,
165+
config_name=config_name,
163166
)
164167

165168
if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):

src/sagemaker/instance_types.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def retrieve_default(
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3737
training_instance_type: Optional[str] = None,
3838
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
39+
config_name: Optional[str] = None,
3940
) -> str:
4041
"""Retrieves the default instance type for the model matching the given arguments.
4142
@@ -64,6 +65,7 @@ def retrieve_default(
6465
Optionally supply this to get a inference instance type conditioned
6566
on the training instance, to ensure compatability of training artifact to inference
6667
instance. (Default: None).
68+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
6769
Returns:
6870
str: The default instance type to use for the model.
6971
@@ -88,6 +90,7 @@ def retrieve_default(
8890
sagemaker_session=sagemaker_session,
8991
training_instance_type=training_instance_type,
9092
model_type=model_type,
93+
config_name=config_name,
9194
)
9295

9396

src/sagemaker/jumpstart/artifacts/environment_variables.py

+7
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4040
instance_type: Optional[str] = None,
4141
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
42+
config_name: Optional[str] = None,
4243
) -> Dict[str, str]:
4344
"""Retrieves the inference environment variables for the model matching the given arguments.
4445
@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
6869
environment variables specific for the instance type.
6970
script (JumpStartScriptScope): The JumpStart script for which to retrieve
7071
environment variables.
72+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
7173
Returns:
7274
dict: the inference environment variables to use for the model.
7375
"""
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
8486
tolerate_vulnerable_model=tolerate_vulnerable_model,
8587
tolerate_deprecated_model=tolerate_deprecated_model,
8688
sagemaker_session=sagemaker_session,
89+
config_name=config_name,
8790
)
8891

8992
default_environment_variables: Dict[str, str] = {}
@@ -121,6 +124,7 @@ def _retrieve_default_environment_variables(
121124
tolerate_deprecated_model=tolerate_deprecated_model,
122125
sagemaker_session=sagemaker_session,
123126
instance_type=instance_type,
127+
config_name=config_name,
124128
)
125129
)
126130

@@ -167,6 +171,7 @@ def _retrieve_gated_model_uri_env_var_value(
167171
tolerate_deprecated_model: bool = False,
168172
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
169173
instance_type: Optional[str] = None,
174+
config_name: Optional[str] = None,
170175
) -> Optional[str]:
171176
"""Retrieves the gated model env var URI matching the given arguments.
172177
@@ -190,6 +195,7 @@ def _retrieve_gated_model_uri_env_var_value(
190195
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
191196
instance_type (str): An instance type to optionally supply in order to get
192197
environment variables specific for the instance type.
198+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
193199
194200
Returns:
195201
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +217,7 @@ def _retrieve_gated_model_uri_env_var_value(
211217
tolerate_vulnerable_model=tolerate_vulnerable_model,
212218
tolerate_deprecated_model=tolerate_deprecated_model,
213219
sagemaker_session=sagemaker_session,
220+
config_name=config_name,
214221
)
215222

216223
s3_key: Optional[str] = (

src/sagemaker/jumpstart/artifacts/hyperparameters.py

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters(
3636
tolerate_deprecated_model: bool = False,
3737
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3838
instance_type: Optional[str] = None,
39+
config_name: Optional[str] = None,
3940
):
4041
"""Retrieves the training hyperparameters for the model matching the given arguments.
4142
@@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters(
6667
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
6768
instance_type (str): An instance type to optionally supply in order to get hyperparameters
6869
specific for the instance type.
70+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
6971
Returns:
7072
dict: the hyperparameters to use for the model.
7173
"""
@@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters(
8284
tolerate_vulnerable_model=tolerate_vulnerable_model,
8385
tolerate_deprecated_model=tolerate_deprecated_model,
8486
sagemaker_session=sagemaker_session,
87+
config_name=config_name,
8588
)
8689

8790
default_hyperparameters: Dict[str, str] = {}

src/sagemaker/jumpstart/artifacts/image_uris.py

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def _retrieve_image_uri(
4646
tolerate_vulnerable_model: bool = False,
4747
tolerate_deprecated_model: bool = False,
4848
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
49+
config_name: Optional[str] = None,
4950
):
5051
"""Retrieves the container image URI for JumpStart models.
5152
@@ -95,6 +96,7 @@ def _retrieve_image_uri(
9596
object, used for SageMaker interactions. If not
9697
specified, one is created using the default AWS configuration
9798
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
99+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
98100
Returns:
99101
str: the ECR URI for the corresponding SageMaker Docker image.
100102
@@ -116,6 +118,7 @@ def _retrieve_image_uri(
116118
tolerate_vulnerable_model=tolerate_vulnerable_model,
117119
tolerate_deprecated_model=tolerate_deprecated_model,
118120
sagemaker_session=sagemaker_session,
121+
config_name=config_name,
119122
)
120123

121124
if image_scope == JumpStartScriptScope.INFERENCE:
@@ -200,4 +203,5 @@ def _retrieve_image_uri(
200203
distribution=distribution,
201204
base_framework_version=base_framework_version_override or base_framework_version,
202205
training_compiler_config=training_compiler_config,
206+
config_name=config_name,
203207
)

src/sagemaker/jumpstart/artifacts/incremental_training.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _model_supports_incremental_training(
3333
tolerate_vulnerable_model: bool = False,
3434
tolerate_deprecated_model: bool = False,
3535
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
config_name: Optional[str] = None,
3637
) -> bool:
3738
"""Returns True if the model supports incremental training.
3839
@@ -54,6 +55,7 @@ def _model_supports_incremental_training(
5455
object, used for SageMaker interactions. If not
5556
specified, one is created using the default AWS configuration
5657
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
58+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
5759
Returns:
5860
bool: the support status for incremental training.
5961
"""
@@ -70,6 +72,7 @@ def _model_supports_incremental_training(
7072
tolerate_vulnerable_model=tolerate_vulnerable_model,
7173
tolerate_deprecated_model=tolerate_deprecated_model,
7274
sagemaker_session=sagemaker_session,
75+
config_name=config_name,
7376
)
7477

7578
return model_specs.supports_incremental_training()

src/sagemaker/jumpstart/artifacts/instance_types.py

+6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _retrieve_default_instance_type(
4040
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4141
training_instance_type: Optional[str] = None,
4242
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
43+
config_name: Optional[str] = None,
4344
) -> str:
4445
"""Retrieves the default instance type for the model.
4546
@@ -68,6 +69,7 @@ def _retrieve_default_instance_type(
6869
Optionally supply this to get a inference instance type conditioned
6970
on the training instance, to ensure compatability of training artifact to inference
7071
instance. (Default: None).
72+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
7173
Returns:
7274
str: the default instance type to use for the model or None.
7375
@@ -89,6 +91,7 @@ def _retrieve_default_instance_type(
8991
tolerate_deprecated_model=tolerate_deprecated_model,
9092
model_type=model_type,
9193
sagemaker_session=sagemaker_session,
94+
config_name=config_name,
9295
)
9396

9497
if scope == JumpStartScriptScope.INFERENCE:
@@ -128,6 +131,7 @@ def _retrieve_instance_types(
128131
tolerate_deprecated_model: bool = False,
129132
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
130133
training_instance_type: Optional[str] = None,
134+
config_name: Optional[str] = None,
131135
) -> List[str]:
132136
"""Retrieves the supported instance types for the model.
133137
@@ -156,6 +160,7 @@ def _retrieve_instance_types(
156160
Optionally supply this to get a inference instance type conditioned
157161
on the training instance, to ensure compatability of training artifact to inference
158162
instance. (Default: None).
163+
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
159164
Returns:
160165
list: the supported instance types to use for the model or None.
161166
@@ -176,6 +181,7 @@ def _retrieve_instance_types(
176181
tolerate_vulnerable_model=tolerate_vulnerable_model,
177182
tolerate_deprecated_model=tolerate_deprecated_model,
178183
sagemaker_session=sagemaker_session,
184+
config_name=config_name,
179185
)
180186

181187
if scope == JumpStartScriptScope.INFERENCE:

0 commit comments

Comments
 (0)