Skip to content

Commit 2438a3f

Browse files
xiongz945beniericnargokulpintaoz-awspravali96
committed
Skip JS model mapping with env vars or image URI provided (#1599)
* Base model trainer (#1521) * Base model trainer * flake8 * add testing notebook * add param validation & set defaults * Implement simple train method * feature: support script mode with local train.sh (#1523) * feature: support script mode with local train.sh * Stop tracking train.sh and add it to .gitignore * update message * make dir if not exist * fix docs * fix: docstyle * Address comments * fix hyperparams * Revert pydantic custom error * pylint * Image Spec refactoring and updates (#1525) * Image Spec refactoring and updates * Unit tests and update function for Image Spec * Fix hugging face test * Fix Tests * Add unit tests for ModelTrainer (#1527) * Add unit tests for ModelTrainer * Flake8 * format * Add example notebook (#1528) * Add testing notebook * format * use smaller data * remove large dataset * update * pylint * flake8 * ignore docstyle in directories with test * format * format * Add enviornment variable bootstrapping script (#1530) * Add enviornment variables scripts * format * fix comment * add docstrings * fix comment * feature: add utility function to capture local snapshot (#1524) * local snapshot * Update pip list command * Remove function calls * Address comments * Address comments * Support intelligent parameters (#1540) * Support intelligent parameters * fix codestyle * Revert Image Spec (#1541) * Cleanup ModelTrainer (#1542) * General image builder (#1546) * General image builder * General image builder * Fix codestyle * Fix codestyle * Move location * Add warnings * Add integ tests * Fix integ test * Fix integ test * Fix region error * Add region * Latest Container Image (#1545) * Latest Container Image * Test Fixes * Parameterized tests and some logic updates * Test fixes * Move to Image URI * Fixes for unit test * Fixes for unit test * Fix codestyle error checks * Cleanup ModelTrainer code (#1552) * feat: add pre-processing and post-processing logic to inference_spec (#1560) * add pre-processing and post-processing logic to inference_spec * fix format * make accept_type and content_type optional * remove accept_type and content_type from pre/post processing * correct typo * Add Distributed Training Support Model Trainer (#1536) * Add path to set Additional Settings in ModelTrainer (#1555) * feature: support HuggingFace models with JumpStart configs * Update bucket name for the model mapping * Mask Sensitive Env Logs in Container (#1568) * Fix unit test * Fix bug in script mode setup ModelTrainer (#1575) * Save mapping as attribute * Fix style issues * Fix style issues * Fix: bypass jumpstart mapping when not in endpoint mode * Skip JS model mapping with env vars or image URI provided * Revert "Merge branch 'aws:master' into dev-morpheus" This reverts commit 26a0b0bb37e0343b3287f5c5c484df22726fc858, reversing changes made to d19d4e178442be4b6e1d07d55498dd76dfac50f0. * Merge branch 'aws:master' into dev-morpheus This reverts commit 076442bd83e5ca977bf5b6ce1b716474d2794feb. * Rebase on master-morpheus * Fix unit test description * Fix TEI integ test * Fix style issue * Fix style issues * Fix schema builder integ tests * Fix TEI integ test * Fix code style issue --------- Co-authored-by: Erick Benitez-Ramos <[email protected]> Co-authored-by: Gokul Anantha Narayanan <[email protected]> Co-authored-by: pintaoz-aws <[email protected]> Co-authored-by: Pravali Uppugunduri <[email protected]> Co-authored-by: Xiong Zeng <[email protected]> Co-authored-by: Gary Wang <[email protected]>
1 parent aa2e62d commit 2438a3f

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

src/sagemaker/serve/builder/model_builder.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,6 @@ def build( # pylint: disable=R0911
872872
Returns:
873873
Type[Model]: A deployable ``Model`` object.
874874
"""
875-
from sagemaker.modules.train.model_trainer import ModelTrainer
876875

877876
self.modes = dict()
878877

@@ -1728,10 +1727,24 @@ def _use_jumpstart_equivalent(self):
17281727
17291728
Replace it with the equivalent if there's one
17301729
"""
1730+
# Do not use the equivalent JS model if image_uri or env_vars is provided
1731+
if self.image_uri or self.env_vars:
1732+
return False
17311733
if not hasattr(self, "_has_jumpstart_equivalent"):
17321734
self._jumpstart_mapping = self._retrieve_hugging_face_model_mapping()
17331735
self._has_jumpstart_equivalent = self.model in self._jumpstart_mapping
17341736
if self._has_jumpstart_equivalent:
1737+
# Use schema builder from HF model metadata
1738+
if not self.schema_builder:
1739+
model_task = None
1740+
if self.model_metadata:
1741+
model_task = self.model_metadata.get("HF_TASK")
1742+
hf_model_md = get_huggingface_model_metadata(self.model)
1743+
if not model_task:
1744+
model_task = hf_model_md.get("pipeline_tag")
1745+
if model_task:
1746+
self._hf_schema_builder_init(model_task)
1747+
17351748
huggingface_model_id = self.model
17361749
jumpstart_model_id = self._jumpstart_mapping[huggingface_model_id]["jumpstart-model-id"]
17371750
self.model = jumpstart_model_id
@@ -1743,7 +1756,7 @@ def _use_jumpstart_equivalent(self):
17431756
"artifact S3 URI and compare them."
17441757
)
17451758
logger.warning( # pylint: disable=logging-fstring-interpolation
1746-
"Please note that for this model we are using the JumpStart's"
1759+
"Please note that for this model we are using the JumpStart's "
17471760
f'local copy "{jumpstart_model_id}" '
17481761
f'of the HuggingFace model "{huggingface_model_id}" you chose. '
17491762
"We strive to keep our local copy synced with the HF model hub closely. "

tests/integ/sagemaker/serve/test_schema_builder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def test_model_builder_happy_path_with_task_provided_local_schema_mode(
111111
if container_startup_timeout:
112112
predictor = model.deploy(
113113
role=role_arn,
114-
instance_count=1,
114+
initial_instance_count=1,
115115
instance_type=instance_type_provided,
116116
container_startup_health_check_timeout=container_startup_timeout,
117117
)
118118
else:
119119
predictor = model.deploy(
120-
role=role_arn, instance_count=1, instance_type=instance_type_provided
120+
role=role_arn, initial_instance_count=1, instance_type=instance_type_provided
121121
)
122122

123123
predicted_outputs = predictor.predict(inputs)
@@ -181,7 +181,7 @@ def test_model_builder_happy_path_with_task_provided_remote_schema_mode(
181181

182182
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
183183
predictor = model.deploy(
184-
role=role_arn, instance_count=1, instance_type=instance_type_provided
184+
role=role_arn, initial_instance_count=1, instance_type=instance_type_provided
185185
)
186186

187187
predicted_outputs = predictor.predict(inputs)

tests/integ/sagemaker/serve/test_serve_tei.py

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def model_builder_model_schema_builder():
4444
model_path=HF_DIR,
4545
model="BAAI/bge-m3",
4646
schema_builder=SchemaBuilder(sample_input, loaded_response),
47+
env_vars={
48+
# Add this to bypass JumpStart model mapping
49+
"HF_MODEL_ID": "BAAI/bge-m3"
50+
},
4751
)
4852

4953

0 commit comments

Comments
 (0)