Skip to content

Commit 741d0a6

Browse files
author
Ashish Gupta
committed
fix sharded model flag
1 parent 65f4cc3 commit 741d0a6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,9 @@ def _model_builder_optimize_wrapper(
12931293
max_runtime_in_sec=max_runtime_in_sec,
12941294
)
12951295

1296+
if sharding_config:
1297+
self.pysdk_model._is_sharded_model = True
1298+
12961299
if input_args:
12971300
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
12981301
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
@@ -1302,9 +1305,6 @@ def _model_builder_optimize_wrapper(
13021305
if not speculative_decoding_config:
13031306
self.pysdk_model.remove_tag_with_key(Tag.SPECULATIVE_DRAFT_MODEL_PROVIDER)
13041307

1305-
if sharding_config:
1306-
self.pysdk_model._is_sharded_model = True
1307-
13081308
return self.pysdk_model
13091309

13101310
def _optimize_for_hf(

0 commit comments

Comments
 (0)