Skip to content

Commit 246d560

Browse files
authored
fix: honor json serialization of HPs (#5164)
* fix: honor json serialization of HPs * test * fix
1 parent 67a3e5a commit 246d560

File tree

2 files changed

+8
-11
lines changed
  • src/sagemaker/modules/train/container_drivers/common
  • tests/unit/sagemaker/modules/train/container_drivers

2 files changed

+8
-11
lines changed

src/sagemaker/modules/train/container_drivers/common/utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def safe_deserialize(data: Any) -> Any:
124124
125125
This function handles the following cases:
126126
1. If `data` is not a string, it returns the input as-is.
127-
2. If `data` is a string and matches common boolean values ("true" or "false"),
128-
it returns the corresponding boolean value (True or False).
129127
3. If `data` is a JSON-encoded string, it attempts to deserialize it using `json.loads()`.
130128
4. If `data` is a string but cannot be decoded as JSON, it returns the original string.
131129
@@ -134,13 +132,6 @@ def safe_deserialize(data: Any) -> Any:
134132
"""
135133
if not isinstance(data, str):
136134
return data
137-
138-
lower_data = data.lower()
139-
if lower_data in ["true"]:
140-
return True
141-
if lower_data in ["false"]:
142-
return False
143-
144135
try:
145136
return json.loads(data)
146137
except json.JSONDecodeError:

tests/unit/sagemaker/modules/train/container_drivers/test_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,14 @@ def test_safe_deserialize_not_a_string():
5959
def test_safe_deserialize_boolean_strings():
6060
assert safe_deserialize("true") is True
6161
assert safe_deserialize("false") is False
62-
assert safe_deserialize("True") is True
63-
assert safe_deserialize("False") is False
62+
63+
# The below are not valid JSON booleans
64+
assert safe_deserialize("True") == "True"
65+
assert safe_deserialize("False") == "False"
66+
assert safe_deserialize("TRUE") == "TRUE"
67+
assert safe_deserialize("FALSE") == "FALSE"
68+
assert safe_deserialize("tRuE") == "tRuE"
69+
assert safe_deserialize("fAlSe") == "fAlSe"
6470

6571

6672
def test_safe_deserialize_valid_json_string():

0 commit comments

Comments
 (0)