Skip to content

Commit b19b4a3

Browse files
author
Ashish Gupta
committed
changes for blackbird - model sharding
1 parent 132fb94 commit b19b4a3

File tree

6 files changed

+97
-6
lines changed

6 files changed

+97
-6
lines changed

src/sagemaker/model.py

+6
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def __init__(
372372
self.endpoint_name = None
373373
self.inference_component_name = None
374374
self._is_compiled_model = False
375+
self._is_sharded_model = False
375376
self._compilation_job_name = None
376377
self._is_edge_packaged_model = False
377378
self.inference_recommender_job_results = None
@@ -1599,6 +1600,11 @@ def deploy(
15991600
if self._base_name is not None:
16001601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16011602

1603+
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604+
logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1605+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
1606+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
1607+
16021608
# Support multiple models on same endpoint
16031609
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16041610
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ def _optimize_for_jumpstart(
681681
quantization_config: Optional[Dict] = None,
682682
compilation_config: Optional[Dict] = None,
683683
speculative_decoding_config: Optional[Dict] = None,
684+
sharding_config: Optional[Dict] = None,
684685
env_vars: Optional[Dict] = None,
685686
vpc_config: Optional[Dict] = None,
686687
kms_key: Optional[str] = None,
@@ -702,6 +703,8 @@ def _optimize_for_jumpstart(
702703
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
703704
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
704705
Defaults to ``None``
706+
sharding_config (Optional[Dict]): Model sharding configuration.
707+
Defaults to ``None``
705708
env_vars (Optional[Dict]): Additional environment variables to run the optimization
706709
container. Defaults to ``None``.
707710
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -727,7 +730,7 @@ def _optimize_for_jumpstart(
727730
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
728731

729732
optimization_config, override_env = _extract_optimization_config_and_env(
730-
quantization_config, compilation_config
733+
quantization_config, compilation_config, sharding_config
731734
)
732735
if not optimization_config and is_compilation:
733736
override_env = override_env or pysdk_model_env_vars
@@ -792,7 +795,7 @@ def _optimize_for_jumpstart(
792795
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
793796
if optimization_env_vars:
794797
self.pysdk_model.env.update(optimization_env_vars)
795-
if quantization_config or is_compilation:
798+
if quantization_config or sharding_config or is_compilation:
796799
return create_optimization_job_args
797800
return None
798801

src/sagemaker/serve/builder/model_builder.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,7 @@ def optimize(
11191119
quantization_config: Optional[Dict] = None,
11201120
compilation_config: Optional[Dict] = None,
11211121
speculative_decoding_config: Optional[Dict] = None,
1122+
sharding_config: Optional[Dict] = None,
11221123
env_vars: Optional[Dict] = None,
11231124
vpc_config: Optional[Dict] = None,
11241125
kms_key: Optional[str] = None,
@@ -1142,6 +1143,8 @@ def optimize(
11421143
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
11431144
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
11441145
Defaults to ``None``
1146+
sharding_config (Optional[Dict]): Model sharding configuration.
1147+
Defaults to ``None``
11451148
env_vars (Optional[Dict]): Additional environment variables to run the optimization
11461149
container. Defaults to ``None``.
11471150
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1170,6 +1173,7 @@ def optimize(
11701173
quantization_config=quantization_config,
11711174
compilation_config=compilation_config,
11721175
speculative_decoding_config=speculative_decoding_config,
1176+
sharding_config=sharding_config,
11731177
env_vars=env_vars,
11741178
vpc_config=vpc_config,
11751179
kms_key=kms_key,
@@ -1189,6 +1193,7 @@ def _model_builder_optimize_wrapper(
11891193
quantization_config: Optional[Dict] = None,
11901194
compilation_config: Optional[Dict] = None,
11911195
speculative_decoding_config: Optional[Dict] = None,
1196+
sharding_config: Optional[Dict] = None,
11921197
env_vars: Optional[Dict] = None,
11931198
vpc_config: Optional[Dict] = None,
11941199
kms_key: Optional[str] = None,
@@ -1212,6 +1217,8 @@ def _model_builder_optimize_wrapper(
12121217
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
12131218
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
12141219
Defaults to ``None``
1220+
sharding_config (Optional[Dict]): Model sharding configuration.
1221+
Defaults to ``None``
12151222
env_vars (Optional[Dict]): Additional environment variables to run the optimization
12161223
container. Defaults to ``None``.
12171224
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1238,6 +1245,12 @@ def _model_builder_optimize_wrapper(
12381245
if quantization_config and compilation_config:
12391246
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12401247

1248+
if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
1249+
raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
1250+
1251+
if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):
1252+
raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")
1253+
12411254
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12421255
self.instance_type = instance_type or self.instance_type
12431256
self.role_arn = role_arn or self.role_arn
@@ -1254,6 +1267,7 @@ def _model_builder_optimize_wrapper(
12541267
quantization_config=quantization_config,
12551268
compilation_config=compilation_config,
12561269
speculative_decoding_config=speculative_decoding_config,
1270+
sharding_config=sharding_config,
12571271
env_vars=env_vars,
12581272
vpc_config=vpc_config,
12591273
kms_key=kms_key,
@@ -1272,6 +1286,7 @@ def _model_builder_optimize_wrapper(
12721286
quantization_config=quantization_config,
12731287
compilation_config=compilation_config,
12741288
speculative_decoding_config=speculative_decoding_config,
1289+
sharding_config=sharding_config,
12751290
env_vars=env_vars,
12761291
vpc_config=vpc_config,
12771292
kms_key=kms_key,
@@ -1287,6 +1302,9 @@ def _model_builder_optimize_wrapper(
12871302
if not speculative_decoding_config:
12881303
self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)
12891304

1305+
if sharding_config:
1306+
self.pysdk_model._is_sharded_model = True
1307+
12901308
return self.pysdk_model
12911309

12921310
def _optimize_for_hf(
@@ -1297,6 +1315,7 @@ def _optimize_for_hf(
12971315
quantization_config: Optional[Dict] = None,
12981316
compilation_config: Optional[Dict] = None,
12991317
speculative_decoding_config: Optional[Dict] = None,
1318+
sharding_config: Optional[Dict] = None,
13001319
env_vars: Optional[Dict] = None,
13011320
vpc_config: Optional[Dict] = None,
13021321
kms_key: Optional[str] = None,
@@ -1312,6 +1331,8 @@ def _optimize_for_hf(
13121331
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
13131332
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
13141333
Defaults to ``None``
1334+
sharding_config (Optional[Dict]): Model sharding configuration.
1335+
Defaults to ``None``
13151336
env_vars (Optional[Dict]): Additional environment variables to run the optimization
13161337
container. Defaults to ``None``.
13171338
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1327,7 +1348,7 @@ def _optimize_for_hf(
13271348
self.pysdk_model, speculative_decoding_config, False
13281349
)
13291350

1330-
if quantization_config or compilation_config:
1351+
if quantization_config or compilation_config or sharding_config:
13311352
create_optimization_job_args = {
13321353
"OptimizationJobName": job_name,
13331354
"DeploymentInstanceType": self.instance_type,

src/sagemaker/serve/utils/optimize_utils.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,15 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
259259

260260

261261
def _extract_optimization_config_and_env(
262-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
262+
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
263+
sharding_config: Optional[Dict] = None
263264
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
264265
"""Extracts optimization config and environment variables.
265266
266267
Args:
267268
quantization_config (Optional[Dict]): The quantization config.
268269
compilation_config (Optional[Dict]): The compilation config.
270+
sharding_config (Optional[Dict]): The sharding config.
269271
270272
Returns:
271273
Optional[Tuple[Optional[Dict], Optional[Dict]]]:
@@ -279,6 +281,10 @@ def _extract_optimization_config_and_env(
279281
return {"ModelCompilationConfig": compilation_config}, compilation_config.get(
280282
"OverrideEnvironment"
281283
)
284+
if sharding_config:
285+
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
286+
"OverrideEnvironment"
287+
)
282288
return None, None
283289

284290

tests/unit/sagemaker/serve/builder/test_model_builder.py

+34
Original file line numberDiff line numberDiff line change
@@ -2667,6 +2667,40 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
26672667
),
26682668
)
26692669

2670+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2671+
def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
2672+
mock_sagemaker_session = Mock()
2673+
model_builder = ModelBuilder(
2674+
model="meta-textgeneration-llama-3-70b",
2675+
sagemaker_session=mock_sagemaker_session,
2676+
)
2677+
2678+
self.assertRaisesRegex(
2679+
ValueError,
2680+
"Sharding config is mutually exclusive and cannot be combined with any other optimization.",
2681+
lambda: model_builder.optimize(
2682+
quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2683+
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2684+
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2685+
),
2686+
)
2687+
2688+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2689+
def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
2690+
mock_sagemaker_session = Mock()
2691+
model_builder = ModelBuilder(
2692+
model="meta-textgeneration-llama-3-70b",
2693+
sagemaker_session=mock_sagemaker_session,
2694+
)
2695+
2696+
self.assertRaisesRegex(
2697+
ValueError,
2698+
"OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
2699+
lambda: model_builder.optimize(
2700+
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2701+
),
2702+
)
2703+
26702704
@patch.object(ModelBuilder, "_prepare_for_mode")
26712705
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
26722706
def test_optimize_for_hf_with_custom_s3_path(

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):
261261

262262

263263
@pytest.mark.parametrize(
264-
"quantization_config, compilation_config, expected_config, expected_env",
264+
"quantization_config, compilation_config, sharding_config, expected_config, expected_env",
265265
[
266266
(
267267
None,
@@ -270,6 +270,7 @@ def test_is_s3_uri(s3_uri, expected):
270270
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
271271
}
272272
},
273+
None,
273274
{
274275
"ModelCompilationConfig": {
275276
"OverrideEnvironment": {
@@ -288,6 +289,7 @@ def test_is_s3_uri(s3_uri, expected):
288289
}
289290
},
290291
None,
292+
None,
291293
{
292294
"ModelQuantizationConfig": {
293295
"OverrideEnvironment": {
@@ -299,7 +301,26 @@ def test_is_s3_uri(s3_uri, expected):
299301
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
300302
},
301303
),
302-
(None, None, None, None),
304+
(
305+
None,
306+
None,
307+
{
308+
"OverrideEnvironment": {
309+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
310+
}
311+
},
312+
{
313+
"ModelShardingConfig": {
314+
"OverrideEnvironment": {
315+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
316+
}
317+
},
318+
},
319+
{
320+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
321+
},
322+
),
323+
(None, None, None, None, None),
303324
],
304325
)
305326
def test_extract_optimization_config_and_env(

0 commit comments

Comments
 (0)