Skip to content

Commit e40fad7

Browse files
author
Ashish Gupta
committed
add unit tests
1 parent 5833143 commit e40fad7

File tree

3 files changed

+109
-4
lines changed

3 files changed

+109
-4
lines changed

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,57 @@ def test_optimize_quantize_for_jumpstart(
11981198

11991199
self.assertIsNotNone(out_put)
12001200

1201+
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
1202+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
1203+
def test_optimize_sharding_for_jumpstart(
1204+
self,
1205+
mock_serve_settings,
1206+
mock_telemetry,
1207+
):
1208+
mock_sagemaker_session = Mock()
1209+
1210+
mock_pysdk_model = Mock()
1211+
mock_pysdk_model.env = {"SAGEMAKER_ENV": "1"}
1212+
mock_pysdk_model.model_data = mock_model_data
1213+
mock_pysdk_model.image_uri = mock_tgi_image_uri
1214+
mock_pysdk_model.list_deployment_configs.return_value = DEPLOYMENT_CONFIGS
1215+
mock_pysdk_model.deployment_config = DEPLOYMENT_CONFIGS[0]
1216+
1217+
sample_input = {
1218+
"inputs": "The diamondback terrapin or simply terrapin is a species "
1219+
"of turtle native to the brackish coastal tidal marshes of the",
1220+
"parameters": {"max_new_tokens": 1024},
1221+
}
1222+
sample_output = [
1223+
{
1224+
"generated_text": "The diamondback terrapin or simply terrapin is a "
1225+
"species of turtle native to the brackish coastal "
1226+
"tidal marshes of the east coast."
1227+
}
1228+
]
1229+
1230+
model_builder = ModelBuilder(
1231+
model="meta-textgeneration-llama-3-70b",
1232+
schema_builder=SchemaBuilder(sample_input, sample_output),
1233+
sagemaker_session=mock_sagemaker_session,
1234+
)
1235+
1236+
model_builder.pysdk_model = mock_pysdk_model
1237+
1238+
out_put = model_builder._optimize_for_jumpstart(
1239+
accept_eula=True,
1240+
sharding_config={
1241+
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
1242+
},
1243+
env_vars={
1244+
"OPTION_TENSOR_PARALLEL_DEGREE": "1",
1245+
"OPTION_MAX_ROLLING_BATCH_SIZE": "2",
1246+
},
1247+
output_path="s3://bucket/code/",
1248+
)
1249+
1250+
self.assertIsNotNone(out_put)
1251+
12011252
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
12021253
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
12031254
@patch(

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2667,6 +2667,39 @@ def test_optimize_exclusive_args(self, mock_get_serve_setting):
26672667
),
26682668
)
26692669

2670+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2671+
def test_optimize_exclusive_sharding(self, mock_get_serve_setting):
2672+
mock_sagemaker_session = Mock()
2673+
model_builder = ModelBuilder(
2674+
model="meta-textgeneration-llama-3-70b",
2675+
sagemaker_session=mock_sagemaker_session,
2676+
)
2677+
2678+
self.assertRaisesRegex(
2679+
ValueError,
2680+
"Sharding config is mutually exclusive and cannot be combined with any other optimization.",
2681+
lambda: model_builder.optimize(
2682+
compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2683+
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2684+
),
2685+
)
2686+
2687+
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
2688+
def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting):
2689+
mock_sagemaker_session = Mock()
2690+
model_builder = ModelBuilder(
2691+
model="meta-textgeneration-llama-3-70b",
2692+
sagemaker_session=mock_sagemaker_session,
2693+
)
2694+
2695+
self.assertRaisesRegex(
2696+
ValueError,
2697+
"OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.",
2698+
lambda: model_builder.optimize(
2699+
sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}},
2700+
),
2701+
)
2702+
26702703
@patch.object(ModelBuilder, "_prepare_for_mode")
26712704
@patch.object(ModelBuilder, "_get_serve_setting", autospec=True)
26722705
def test_optimize_for_hf_with_custom_s3_path(

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):
261261

262262

263263
@pytest.mark.parametrize(
264-
"quantization_config, compilation_config, expected_config, expected_env",
264+
"quantization_config, compilation_config, sharding_config, expected_config, expected_env",
265265
[
266266
(
267267
None,
@@ -270,6 +270,7 @@ def test_is_s3_uri(s3_uri, expected):
270270
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
271271
}
272272
},
273+
None,
273274
{
274275
"ModelCompilationConfig": {
275276
"OverrideEnvironment": {
@@ -288,6 +289,7 @@ def test_is_s3_uri(s3_uri, expected):
288289
}
289290
},
290291
None,
292+
None,
291293
{
292294
"ModelQuantizationConfig": {
293295
"OverrideEnvironment": {
@@ -299,13 +301,32 @@ def test_is_s3_uri(s3_uri, expected):
299301
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
300302
},
301303
),
302-
(None, None, None, None),
304+
(
305+
None,
306+
None,
307+
{
308+
"OverrideEnvironment": {
309+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
310+
}
311+
},
312+
{
313+
"ModelShardingConfig": {
314+
"OverrideEnvironment": {
315+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
316+
}
317+
},
318+
},
319+
{
320+
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
321+
},
322+
),
323+
(None, None, None, None, None),
303324
],
304325
)
305326
def test_extract_optimization_config_and_env(
306-
quantization_config, compilation_config, expected_config, expected_env
327+
quantization_config, compilation_config, sharding_config, expected_config, expected_env
307328
):
308-
assert _extract_optimization_config_and_env(quantization_config, compilation_config) == (
329+
assert _extract_optimization_config_and_env(quantization_config, compilation_config, sharding_config) == (
309330
expected_config,
310331
expected_env,
311332
)

0 commit comments

Comments
 (0)