@@ -1119,6 +1119,7 @@ def optimize(
1119
1119
quantization_config : Optional [Dict ] = None ,
1120
1120
compilation_config : Optional [Dict ] = None ,
1121
1121
speculative_decoding_config : Optional [Dict ] = None ,
1122
+ sharding_config : Optional [Dict ] = None ,
1122
1123
env_vars : Optional [Dict ] = None ,
1123
1124
vpc_config : Optional [Dict ] = None ,
1124
1125
kms_key : Optional [str ] = None ,
@@ -1142,6 +1143,8 @@ def optimize(
1142
1143
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1143
1144
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1144
1145
Defaults to ``None``.
1146
+ sharding_config (Optional[Dict]): Model sharding configuration.
1147
+ Defaults to ``None``.
1145
1148
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1146
1149
container. Defaults to ``None``.
1147
1150
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
1170
1173
quantization_config = quantization_config ,
1171
1174
compilation_config = compilation_config ,
1172
1175
speculative_decoding_config = speculative_decoding_config ,
1176
+ sharding_config = sharding_config ,
1173
1177
env_vars = env_vars ,
1174
1178
vpc_config = vpc_config ,
1175
1179
kms_key = kms_key ,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
1189
1193
quantization_config : Optional [Dict ] = None ,
1190
1194
compilation_config : Optional [Dict ] = None ,
1191
1195
speculative_decoding_config : Optional [Dict ] = None ,
1196
+ sharding_config : Optional [Dict ] = None ,
1192
1197
env_vars : Optional [Dict ] = None ,
1193
1198
vpc_config : Optional [Dict ] = None ,
1194
1199
kms_key : Optional [str ] = None ,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
1212
1217
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1213
1218
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1214
1219
Defaults to ``None``.
1220
+ sharding_config (Optional[Dict]): Model sharding configuration.
1221
+ Defaults to ``None``.
1215
1222
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1216
1223
container. Defaults to ``None``.
1217
1224
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
1238
1245
if quantization_config and compilation_config :
1239
1246
raise ValueError ("Quantization config and compilation config are mutually exclusive." )
1240
1247
1248
+ if sharding_config and (quantization_config or compilation_config or speculative_decoding_config ):
1249
+ raise ValueError ("Sharding config is mutually exclusive and cannot be combined with any other optimization." )
1250
+
1251
+ if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars ) or (sharding_config .get ("OverrideEnvironment" ) and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config ["OverrideEnvironment" ])):
1252
+ raise ValueError ("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config." )
1253
+
1241
1254
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1242
1255
self .instance_type = instance_type or self .instance_type
1243
1256
self .role_arn = role_arn or self .role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
1254
1267
quantization_config = quantization_config ,
1255
1268
compilation_config = compilation_config ,
1256
1269
speculative_decoding_config = speculative_decoding_config ,
1270
+ sharding_config = sharding_config ,
1257
1271
env_vars = env_vars ,
1258
1272
vpc_config = vpc_config ,
1259
1273
kms_key = kms_key ,
@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
1272
1286
quantization_config = quantization_config ,
1273
1287
compilation_config = compilation_config ,
1274
1288
speculative_decoding_config = speculative_decoding_config ,
1289
+ sharding_config = sharding_config ,
1275
1290
env_vars = env_vars ,
1276
1291
vpc_config = vpc_config ,
1277
1292
kms_key = kms_key ,
@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
1287
1302
if not speculative_decoding_config :
1288
1303
self .pysdk_model .remove_tag_with_key (Tag .SPECULATIVE_DRAFT_MODEL_PROVIDER )
1289
1304
1305
+ if sharding_config :
1306
+ self .pysdk_model ._is_sharded_model = True
1307
+
1290
1308
return self .pysdk_model
1291
1309
1292
1310
def _optimize_for_hf (
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
1297
1315
quantization_config : Optional [Dict ] = None ,
1298
1316
compilation_config : Optional [Dict ] = None ,
1299
1317
speculative_decoding_config : Optional [Dict ] = None ,
1318
+ sharding_config : Optional [Dict ] = None ,
1300
1319
env_vars : Optional [Dict ] = None ,
1301
1320
vpc_config : Optional [Dict ] = None ,
1302
1321
kms_key : Optional [str ] = None ,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
1312
1331
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1313
1332
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1314
1333
Defaults to ``None``.
1334
+ sharding_config (Optional[Dict]): Model sharding configuration.
1335
+ Defaults to ``None``.
1315
1336
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1316
1337
container. Defaults to ``None``.
1317
1338
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
1327
1348
self .pysdk_model , speculative_decoding_config , False
1328
1349
)
1329
1350
1330
- if quantization_config or compilation_config :
1351
+ if quantization_config or compilation_config or sharding_config :
1331
1352
create_optimization_job_args = {
1332
1353
"OptimizationJobName" : job_name ,
1333
1354
"DeploymentInstanceType" : self .instance_type ,
0 commit comments