Skip to content

Commit 0785831

Browse files
author
Joseph Zhang
committed
Disable network isolation if using sharded models.
1 parent 741d0a6 commit 0785831

File tree

7 files changed

+47
-14
lines changed

7 files changed

+47
-14
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,10 +1601,18 @@ def deploy(
16011601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16021602

16031603
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604-
logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1605-
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
1604+
logging.warning(
1605+
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1606+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1607+
)
16061608
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
16071609

1610+
if self._is_sharded_model and self._enable_network_isolation:
1611+
raise ValueError(
1612+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1613+
"Loading of model requires network access."
1614+
)
1615+
16081616
# Support multiple models on same endpoint
16091617
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
16101618
if endpoint_name:

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,14 @@ def _optimize_for_jumpstart(
795795
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
796796
if optimization_env_vars:
797797
self.pysdk_model.env.update(optimization_env_vars)
798+
799+
if sharding_config and self.pysdk_model._enable_network_isolation:
800+
logger.warning(
801+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
802+
"Loading of model requires network access. Setting it to False."
803+
)
804+
self.pysdk_model._enable_network_isolation = False
805+
798806
if quantization_config or sharding_config or is_compilation:
799807
return create_optimization_job_args
800808
return None

src/sagemaker/serve/builder/model_builder.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,11 +1245,25 @@ def _model_builder_optimize_wrapper(
12451245
if quantization_config and compilation_config:
12461246
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12471247

1248-
if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
1249-
raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
1248+
if sharding_config and (
1249+
quantization_config or compilation_config or speculative_decoding_config
1250+
):
1251+
raise ValueError(
1252+
"Sharding config is mutually exclusive and cannot be combined with any "
1253+
"other optimization."
1254+
)
12501255

1251-
if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):
1252-
raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")
1256+
if sharding_config and (
1257+
(env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars)
1258+
or (
1259+
sharding_config.get("OverrideEnvironment")
1260+
and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"]
1261+
)
1262+
):
1263+
raise ValueError(
1264+
"OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with "
1265+
"sharding config."
1266+
)
12531267

12541268
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12551269
self.instance_type = instance_type or self.instance_type

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
259259

260260

261261
def _extract_optimization_config_and_env(
262-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
263-
sharding_config: Optional[Dict] = None
262+
quantization_config: Optional[Dict] = None,
263+
compilation_config: Optional[Dict] = None,
264+
sharding_config: Optional[Dict] = None,
264265
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
265266
"""Extracts optimization config and environment variables.
266267
@@ -282,9 +283,7 @@ def _extract_optimization_config_and_env(
282283
"OverrideEnvironment"
283284
)
284285
if sharding_config:
285-
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
286-
"OverrideEnvironment"
287-
)
286+
return {"ModelShardingConfig": sharding_config}, sharding_config.get("OverrideEnvironment")
288287
return None, None
289288

290289

tests/unit/sagemaker/model/test_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
958958
sagemaker_session.endpoint_in_service_or_not.reset_mock()
959959
sagemaker_session.create_model.reset_mock()
960960

961+
961962
@patch("sagemaker.utils.repack_model")
962963
@patch("sagemaker.fw_utils.tar_and_upload_dir")
963964
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
@@ -967,7 +968,7 @@ def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
967968
HuggingFaceModel: {
968969
"pytorch_version": "1.7.1",
969970
"py_version": "py36",
970-
"transformers_version": "4.6.1"
971+
"transformers_version": "4.6.1",
971972
},
972973
}
973974

@@ -1007,6 +1008,7 @@ def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
10071008
sagemaker_session.endpoint_in_service_or_not.reset_mock()
10081009
sagemaker_session.create_model.reset_mock()
10091010

1011+
10101012
@patch("sagemaker.utils.repack_model")
10111013
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
10121014

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2694,7 +2694,7 @@ def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
26942694

26952695
self.assertRaisesRegex(
26962696
ValueError,
2697-
"OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
2697+
"OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.",
26982698
lambda: model_builder.optimize(
26992699
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
27002700
),

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ def test_is_s3_uri(s3_uri, expected):
326326
def test_extract_optimization_config_and_env(
327327
quantization_config, compilation_config, sharding_config, expected_config, expected_env
328328
):
329-
assert _extract_optimization_config_and_env(quantization_config, compilation_config, sharding_config) == (
329+
assert _extract_optimization_config_and_env(
330+
quantization_config, compilation_config, sharding_config
331+
) == (
330332
expected_config,
331333
expected_env,
332334
)

0 commit comments

Comments
 (0)