@@ -37,6 +37,7 @@ def _retrieve_default_instance_type(
37
37
tolerate_vulnerable_model : bool = False ,
38
38
tolerate_deprecated_model : bool = False ,
39
39
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
40
+ training_instance_type : Optional [str ] = None ,
40
41
) -> str :
41
42
"""Retrieves the default instance type for the model.
42
43
@@ -60,6 +61,11 @@ def _retrieve_default_instance_type(
60
61
object, used for SageMaker interactions. If not
61
62
specified, one is created using the default AWS configuration
62
63
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64
+ training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
65
+ instance type used for the training job that produced the fine-tuned weights.
66
+ Optionally supply this to get an inference instance type conditioned
67
+ on the training instance, to ensure compatibility of training artifact to inference
68
+ instance. (Default: None).
63
69
Returns:
64
70
str: the default instance type to use for the model or None.
65
71
@@ -82,7 +88,21 @@ def _retrieve_default_instance_type(
82
88
)
83
89
84
90
if scope == JumpStartScriptScope .INFERENCE :
85
- default_instance_type = model_specs .default_inference_instance_type
91
+ instance_specific_default_instance_type = (
92
+ (
93
+ model_specs .training_instance_type_variants .get_instance_specific_default_inference_instance_type ( # pylint: disable=C0301 # noqa: E501
94
+ training_instance_type
95
+ )
96
+ )
97
+ if training_instance_type is not None
98
+ and getattr (model_specs , "training_instance_type_variants" , None ) is not None
99
+ else None
100
+ )
101
+ default_instance_type = (
102
+ instance_specific_default_instance_type
103
+ if instance_specific_default_instance_type is not None
104
+ else model_specs .default_inference_instance_type
105
+ )
86
106
elif scope == JumpStartScriptScope .TRAINING :
87
107
default_instance_type = model_specs .default_training_instance_type
88
108
else :
@@ -103,6 +123,7 @@ def _retrieve_instance_types(
103
123
tolerate_vulnerable_model : bool = False ,
104
124
tolerate_deprecated_model : bool = False ,
105
125
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
126
+ training_instance_type : Optional [str ] = None ,
106
127
) -> List [str ]:
107
128
"""Retrieves the supported instance types for the model.
108
129
@@ -126,6 +147,11 @@ def _retrieve_instance_types(
126
147
object, used for SageMaker interactions. If not
127
148
specified, one is created using the default AWS configuration
128
149
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150
+ training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
151
+ instance type used for the training job that produced the fine-tuned weights.
152
+ Optionally supply this to get an inference instance type conditioned
153
+ on the training instance, to ensure compatibility of training artifact to inference
154
+ instance. (Default: None).
129
155
Returns:
130
156
list: the supported instance types to use for the model or None.
131
157
@@ -148,8 +174,24 @@ def _retrieve_instance_types(
148
174
)
149
175
150
176
if scope == JumpStartScriptScope .INFERENCE :
151
- instance_types = model_specs .supported_inference_instance_types
177
+ default_instance_types = model_specs .supported_inference_instance_types or []
178
+ instance_specific_instance_types = (
179
+ model_specs .training_instance_type_variants .get_instance_specific_supported_inference_instance_types ( # pylint: disable=C0301 # noqa: E501
180
+ training_instance_type
181
+ )
182
+ if training_instance_type is not None
183
+ and getattr (model_specs , "training_instance_type_variants" , None ) is not None
184
+ else []
185
+ )
186
+ instance_types = (
187
+ instance_specific_instance_types
188
+ if len (instance_specific_instance_types ) > 0
189
+ else default_instance_types
190
+ )
191
+
152
192
elif scope == JumpStartScriptScope .TRAINING :
193
+ if training_instance_type is not None :
194
+ raise ValueError ("Cannot use `training_instance_type` argument " "with training scope." )
153
195
instance_types = model_specs .supported_training_instance_types
154
196
else :
155
197
raise NotImplementedError (
0 commit comments