Skip to content

Commit b17d332

Browse files
authored
Fix:invalid component error with new metadata (#4634)
* fix: invalid component name * tests * format * fix vulnerable model integ tests llama 2 * updated * fix: training dataset location
1 parent 72e0c97 commit b17d332

File tree

6 files changed

+23
-5
lines changed

6 files changed

+23
-5
lines changed

src/sagemaker/jumpstart/estimator.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,12 @@ def attach(
734734

735735
model_version = model_version or "*"
736736

737-
additional_kwargs = {"model_id": model_id, "model_version": model_version}
737+
additional_kwargs = {
738+
"model_id": model_id,
739+
"model_version": model_version,
740+
"tolerate_vulnerable_model": True, # model is already trained
741+
"tolerate_deprecated_model": True, # model is already trained
742+
}
738743

739744
model_specs = verify_model_region_and_return_specs(
740745
model_id=model_id,

src/sagemaker/jumpstart/types.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
10641064
Dictionary representation of the config component.
10651065
"""
10661066
for field in json_obj.keys():
1067-
if field not in self.__slots__:
1068-
raise ValueError(f"Invalid component field: {field}")
1069-
setattr(self, field, json_obj[field])
1067+
if field in self.__slots__:
1068+
setattr(self, field, json_obj[field])
10701069

10711070

10721071
class JumpStartMetadataConfig(JumpStartDataHolderType):

tests/integ/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4848
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
4949
("meta-textgeneration-llama-2-7b", "2.*"): ("training-datasets/sec_amazon/"),
5050
("meta-textgeneration-llama-2-7b", "3.*"): ("training-datasets/sec_amazon/"),
51+
("meta-textgeneration-llama-2-7b", "4.*"): ("training-datasets/sec_amazon/"),
5152
("meta-textgenerationneuron-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
5253
}
5354

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_gated_model_training_v1(setup):
140140
def test_gated_model_training_v2(setup):
141141

142142
model_id = "meta-textgeneration-llama-2-7b"
143-
model_version = "3.*" # model artifacts retrieved from jumpstart-private-cache-* buckets
143+
model_version = "4.*" # model artifacts retrieved from jumpstart-private-cache-* buckets
144144

145145
estimator = JumpStartEstimator(
146146
model_id=model_id,
@@ -150,6 +150,7 @@ def test_gated_model_training_v2(setup):
150150
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
151151
environment={"accept_eula": "true"},
152152
max_run=259200, # avoid exceeding resource limits
153+
tolerate_vulnerable_model=True, # tolerate old version of model
153154
)
154155

155156
# uses ml.g5.12xlarge instance

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,8 @@ def test_jumpstart_estimator_attach_eula_model(
10101010
"model_id": "gemma-model",
10111011
"model_version": "*",
10121012
"environment": {"accept_eula": "true"},
1013+
"tolerate_vulnerable_model": True,
1014+
"tolerate_deprecated_model": True,
10131015
},
10141016
)
10151017

@@ -1053,6 +1055,8 @@ def test_jumpstart_estimator_attach_no_model_id_happy_case(
10531055
additional_kwargs={
10541056
"model_id": "js-trainable-model-prepacked",
10551057
"model_version": "1.0.0",
1058+
"tolerate_vulnerable_model": True,
1059+
"tolerate_deprecated_model": True,
10561060
},
10571061
)
10581062

tests/unit/sagemaker/jumpstart/test_types.py

+8
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,14 @@ def test_inference_configs_parsing():
10521052
)
10531053
assert list(config.config_components.keys()) == ["neuron-inference"]
10541054

1055+
spec = {
1056+
**BASE_SPEC,
1057+
**INFERENCE_CONFIGS,
1058+
**INFERENCE_CONFIG_RANKINGS,
1059+
"unrecognized-field": "blah", # New fields in base metadata fields should be ignored
1060+
}
1061+
specs1 = JumpStartModelSpecs(spec)
1062+
10551063

10561064
def test_set_inference_configs():
10571065
spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS}

0 commit comments

Comments
 (0)