Skip to content

Commit 2e017c4

Browse files
committed
fix: hyperparameter validation error messages
1 parent c51d9bd commit 2e017c4

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

src/sagemaker/jumpstart/validators.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _validate_hyperparameter(
4949

5050
if len(hyperparameter_spec) > 1:
5151
raise JumpStartHyperparametersError(
52-
f"Unable to perform validation -- found multiple hyperparameter "
52+
"Unable to perform validation -- found multiple hyperparameter "
5353
f"'{hyperparameter_name}' in model specs."
5454
)
5555

@@ -76,35 +76,35 @@ def _validate_hyperparameter(
7676
if hyperparameter_value not in hyperparameter_spec.options:
7777
raise JumpStartHyperparametersError(
7878
f"Hyperparameter '{hyperparameter_name}' must have one of the following "
79-
f"values: {', '.join(hyperparameter_spec.options)}"
79+
f"values: {', '.join(hyperparameter_spec.options)}."
8080
)
8181

8282
if hasattr(hyperparameter_spec, "min"):
8383
if len(hyperparameter_value) < hyperparameter_spec.min:
8484
raise JumpStartHyperparametersError(
8585
f"Hyperparameter '{hyperparameter_name}' must have length no less than "
86-
f"{hyperparameter_spec.min}"
86+
f"{hyperparameter_spec.min}."
8787
)
8888

8989
if hasattr(hyperparameter_spec, "exclusive_min"):
9090
if len(hyperparameter_value) <= hyperparameter_spec.exclusive_min:
9191
raise JumpStartHyperparametersError(
9292
f"Hyperparameter '{hyperparameter_name}' must have length greater than "
93-
f"{hyperparameter_spec.exclusive_min}"
93+
f"{hyperparameter_spec.exclusive_min}."
9494
)
9595

9696
if hasattr(hyperparameter_spec, "max"):
9797
if len(hyperparameter_value) > hyperparameter_spec.max:
9898
raise JumpStartHyperparametersError(
9999
f"Hyperparameter '{hyperparameter_name}' must have length no greater than "
100-
f"{hyperparameter_spec.max}"
100+
f"{hyperparameter_spec.max}."
101101
)
102102

103103
if hasattr(hyperparameter_spec, "exclusive_max"):
104104
if len(hyperparameter_value) >= hyperparameter_spec.exclusive_max:
105105
raise JumpStartHyperparametersError(
106106
f"Hyperparameter '{hyperparameter_name}' must have length less than "
107-
f"{hyperparameter_spec.exclusive_max}"
107+
f"{hyperparameter_spec.exclusive_max}."
108108
)
109109

110110
# validate numeric types
@@ -125,35 +125,35 @@ def _validate_hyperparameter(
125125
if not hyperparameter_value_str[start_index:].isdigit():
126126
raise JumpStartHyperparametersError(
127127
f"Hyperparameter '{hyperparameter_name}' must be integer type "
128-
"('{hyperparameter_value}')."
128+
f"('{hyperparameter_value}')."
129129
)
130130

131131
if hasattr(hyperparameter_spec, "min"):
132132
if numeric_hyperparam_value < hyperparameter_spec.min:
133133
raise JumpStartHyperparametersError(
134134
f"Hyperparameter '{hyperparameter_name}' can be no less than "
135-
"{hyperparameter_spec.min}."
135+
f"{hyperparameter_spec.min}."
136136
)
137137

138138
if hasattr(hyperparameter_spec, "max"):
139139
if numeric_hyperparam_value > hyperparameter_spec.max:
140140
raise JumpStartHyperparametersError(
141141
f"Hyperparameter '{hyperparameter_name}' can be no greater than "
142-
"{hyperparameter_spec.max}."
142+
f"{hyperparameter_spec.max}."
143143
)
144144

145145
if hasattr(hyperparameter_spec, "exclusive_min"):
146146
if numeric_hyperparam_value <= hyperparameter_spec.exclusive_min:
147147
raise JumpStartHyperparametersError(
148148
f"Hyperparameter '{hyperparameter_name}' must be greater than "
149-
"{hyperparameter_spec.exclusive_min}."
149+
f"{hyperparameter_spec.exclusive_min}."
150150
)
151151

152152
if hasattr(hyperparameter_spec, "exclusive_max"):
153153
if numeric_hyperparam_value >= hyperparameter_spec.exclusive_max:
154154
raise JumpStartHyperparametersError(
155155
f"Hyperparameter '{hyperparameter_name}' must be less than "
156-
"{hyperparameter_spec.exclusive_max}."
156+
f"{hyperparameter_spec.exclusive_max}."
157157
)
158158

159159

tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py

+51-16
Original file line numberDiff line numberDiff line change
@@ -147,49 +147,54 @@ def add_options_to_hyperparameter(*largs, **kwargs):
147147
)
148148

149149
hyperparameter_to_test["batch-size"] = "0"
150-
with pytest.raises(JumpStartHyperparametersError):
150+
with pytest.raises(JumpStartHyperparametersError) as e:
151151
hyperparameters.validate(
152152
region=region,
153153
model_id=model_id,
154154
model_version=model_version,
155155
hyperparameters=hyperparameter_to_test,
156156
)
157+
assert str(e.value) == ("Hyperparameter 'batch-size' " "can be no less than 1.")
157158

158159
hyperparameter_to_test["batch-size"] = "-1"
159-
with pytest.raises(JumpStartHyperparametersError):
160+
with pytest.raises(JumpStartHyperparametersError) as e:
160161
hyperparameters.validate(
161162
region=region,
162163
model_id=model_id,
163164
model_version=model_version,
164165
hyperparameters=hyperparameter_to_test,
165166
)
167+
assert str(e.value) == ("Hyperparameter 'batch-size' can be no " "less than 1.")
166168

167169
hyperparameter_to_test["batch-size"] = "-1.5"
168-
with pytest.raises(JumpStartHyperparametersError):
170+
with pytest.raises(JumpStartHyperparametersError) as e:
169171
hyperparameters.validate(
170172
region=region,
171173
model_id=model_id,
172174
model_version=model_version,
173175
hyperparameters=hyperparameter_to_test,
174176
)
177+
assert str(e.value) == ("Hyperparameter 'batch-size' must be " "integer type ('-1.5').")
175178

176179
hyperparameter_to_test["batch-size"] = "1.5"
177-
with pytest.raises(JumpStartHyperparametersError):
180+
with pytest.raises(JumpStartHyperparametersError) as e:
178181
hyperparameters.validate(
179182
region=region,
180183
model_id=model_id,
181184
model_version=model_version,
182185
hyperparameters=hyperparameter_to_test,
183186
)
187+
assert str(e.value) == ("Hyperparameter 'batch-size' must be integer " "type ('1.5').")
184188

185189
hyperparameter_to_test["batch-size"] = "99999"
186-
with pytest.raises(JumpStartHyperparametersError):
190+
with pytest.raises(JumpStartHyperparametersError) as e:
187191
hyperparameters.validate(
188192
region=region,
189193
model_id=model_id,
190194
model_version=model_version,
191195
hyperparameters=hyperparameter_to_test,
192196
)
197+
assert str(e.value) == ("Hyperparameter 'batch-size' can be no greater " "than 1024.")
193198

194199
hyperparameter_to_test["batch-size"] = 5
195200
hyperparameters.validate(
@@ -210,13 +215,17 @@ def add_options_to_hyperparameter(*largs, **kwargs):
210215
)
211216
for val in [None, "", 5, "Truesday", "Falsehood"]:
212217
hyperparameter_to_test["test_bool_param"] = val
213-
with pytest.raises(JumpStartHyperparametersError):
218+
with pytest.raises(JumpStartHyperparametersError) as e:
214219
hyperparameters.validate(
215220
region=region,
216221
model_id=model_id,
217222
model_version=model_version,
218223
hyperparameters=hyperparameter_to_test,
219224
)
225+
assert str(e.value) == (
226+
"Expecting boolean valued hyperparameter, " f"but got '{str(val)}'."
227+
)
228+
220229
hyperparameter_to_test["test_bool_param"] = original_bool_val
221230

222231
original_exclusive_min_val = hyperparameter_to_test["test_exclusive_min_param"]
@@ -230,13 +239,16 @@ def add_options_to_hyperparameter(*largs, **kwargs):
230239
)
231240
for val in [1, 1 - 1e-99, -99]:
232241
hyperparameter_to_test["test_exclusive_min_param"] = val
233-
with pytest.raises(JumpStartHyperparametersError):
242+
with pytest.raises(JumpStartHyperparametersError) as e:
234243
hyperparameters.validate(
235244
region=region,
236245
model_id=model_id,
237246
model_version=model_version,
238247
hyperparameters=hyperparameter_to_test,
239248
)
249+
assert str(e.value) == (
250+
"Hyperparameter 'test_exclusive_min_param' must " "be greater than 1."
251+
)
240252
hyperparameter_to_test["test_exclusive_min_param"] = original_exclusive_min_val
241253

242254
original_exclusive_max_val = hyperparameter_to_test["test_exclusive_max_param"]
@@ -250,13 +262,15 @@ def add_options_to_hyperparameter(*largs, **kwargs):
250262
)
251263
for val in [4, 5, 99]:
252264
hyperparameter_to_test["test_exclusive_max_param"] = val
253-
with pytest.raises(JumpStartHyperparametersError):
265+
with pytest.raises(JumpStartHyperparametersError) as e:
254266
hyperparameters.validate(
255267
region=region,
256268
model_id=model_id,
257269
model_version=model_version,
258270
hyperparameters=hyperparameter_to_test,
259271
)
272+
assert str(e.value) == "Hyperparameter 'test_exclusive_max_param' must be less than 4."
273+
260274
hyperparameter_to_test["test_exclusive_max_param"] = original_exclusive_max_val
261275

262276
original_exclusive_max_text_val = hyperparameter_to_test["test_exclusive_max_param_text"]
@@ -270,13 +284,17 @@ def add_options_to_hyperparameter(*largs, **kwargs):
270284
)
271285
for val in ["123456", "123456789"]:
272286
hyperparameter_to_test["test_exclusive_max_param_text"] = val
273-
with pytest.raises(JumpStartHyperparametersError):
287+
with pytest.raises(JumpStartHyperparametersError) as e:
274288
hyperparameters.validate(
275289
region=region,
276290
model_id=model_id,
277291
model_version=model_version,
278292
hyperparameters=hyperparameter_to_test,
279293
)
294+
assert (
295+
str(e.value)
296+
== "Hyperparameter 'test_exclusive_max_param_text' must have length less than 6."
297+
)
280298
hyperparameter_to_test["test_exclusive_max_param_text"] = original_exclusive_max_text_val
281299

282300
original_max_text_val = hyperparameter_to_test["test_max_param_text"]
@@ -290,13 +308,17 @@ def add_options_to_hyperparameter(*largs, **kwargs):
290308
)
291309
for val in ["1234567", "123456789"]:
292310
hyperparameter_to_test["test_max_param_text"] = val
293-
with pytest.raises(JumpStartHyperparametersError):
311+
with pytest.raises(JumpStartHyperparametersError) as e:
294312
hyperparameters.validate(
295313
region=region,
296314
model_id=model_id,
297315
model_version=model_version,
298316
hyperparameters=hyperparameter_to_test,
299317
)
318+
assert (
319+
str(e.value)
320+
== "Hyperparameter 'test_max_param_text' must have length no greater than 6."
321+
)
300322
hyperparameter_to_test["test_max_param_text"] = original_max_text_val
301323

302324
original_exclusive_min_text_val = hyperparameter_to_test["test_exclusive_min_param_text"]
@@ -310,13 +332,16 @@ def add_options_to_hyperparameter(*largs, **kwargs):
310332
)
311333
for val in ["1", "d", ""]:
312334
hyperparameter_to_test["test_exclusive_min_param_text"] = val
313-
with pytest.raises(JumpStartHyperparametersError):
335+
with pytest.raises(JumpStartHyperparametersError) as e:
314336
hyperparameters.validate(
315337
region=region,
316338
model_id=model_id,
317339
model_version=model_version,
318340
hyperparameters=hyperparameter_to_test,
319341
)
342+
assert str(e.value) == (
343+
"Hyperparameter 'test_exclusive_min_param_text' must have length greater " "than 1."
344+
)
320345
hyperparameter_to_test["test_exclusive_min_param_text"] = original_exclusive_min_text_val
321346

322347
original_min_text_val = hyperparameter_to_test["test_min_param_text"]
@@ -330,24 +355,31 @@ def add_options_to_hyperparameter(*largs, **kwargs):
330355
)
331356
for val in [""]:
332357
hyperparameter_to_test["test_min_param_text"] = val
333-
with pytest.raises(JumpStartHyperparametersError):
358+
with pytest.raises(JumpStartHyperparametersError) as e:
334359
hyperparameters.validate(
335360
region=region,
336361
model_id=model_id,
337362
model_version=model_version,
338363
hyperparameters=hyperparameter_to_test,
339364
)
365+
assert str(e.value) == (
366+
"Hyperparameter 'test_min_param_text' " "must have length no less than 1."
367+
)
340368
hyperparameter_to_test["test_min_param_text"] = original_min_text_val
341369

342370
del hyperparameter_to_test["batch-size"]
343371
hyperparameter_to_test["penalty"] = "blah"
344-
with pytest.raises(JumpStartHyperparametersError):
372+
with pytest.raises(JumpStartHyperparametersError) as e:
345373
hyperparameters.validate(
346374
region=region,
347375
model_id=model_id,
348376
model_version=model_version,
349377
hyperparameters=hyperparameter_to_test,
350378
)
379+
assert str(e.value) == (
380+
"Hyperparameter 'penalty' must have one of the following values: l1, l2, elasticnet,"
381+
" none."
382+
)
351383

352384
hyperparameter_to_test["penalty"] = "elasticnet"
353385
hyperparameters.validate(
@@ -411,14 +443,15 @@ def add_options_to_hyperparameter(*largs, **kwargs):
411443
)
412444

413445
del hyperparameter_to_test["adam-learning-rate"]
414-
with pytest.raises(JumpStartHyperparametersError):
446+
with pytest.raises(JumpStartHyperparametersError) as e:
415447
hyperparameters.validate(
416448
region=region,
417449
model_id=model_id,
418450
model_version=model_version,
419451
hyperparameters=hyperparameter_to_test,
420452
validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM,
421453
)
454+
assert str(e.value) == "Cannot find algorithm hyperparameter for 'adam-learning-rate'."
422455

423456

424457
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -454,28 +487,30 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs):
454487

455488
del hyperparameter_to_test["sagemaker_submit_directory"]
456489

457-
with pytest.raises(JumpStartHyperparametersError):
490+
with pytest.raises(JumpStartHyperparametersError) as e:
458491
hyperparameters.validate(
459492
region=region,
460493
model_id=model_id,
461494
model_version=model_version,
462495
hyperparameters=hyperparameter_to_test,
463496
validation_mode=HyperparameterValidationMode.VALIDATE_ALL,
464497
)
498+
assert str(e.value) == "Cannot find hyperparameter for 'sagemaker_submit_directory'."
465499

466500
hyperparameter_to_test[
467501
"sagemaker_submit_directory"
468502
] = "/opt/ml/input/data/code/sourcedir.tar.gz"
469503
del hyperparameter_to_test["epochs"]
470504

471-
with pytest.raises(JumpStartHyperparametersError):
505+
with pytest.raises(JumpStartHyperparametersError) as e:
472506
hyperparameters.validate(
473507
region=region,
474508
model_id=model_id,
475509
model_version=model_version,
476510
hyperparameters=hyperparameter_to_test,
477511
validation_mode=HyperparameterValidationMode.VALIDATE_ALL,
478512
)
513+
assert str(e.value) == "Cannot find hyperparameter for 'epochs'."
479514

480515
hyperparameter_to_test["epochs"] = "3"
481516

0 commit comments

Comments
 (0)