Skip to content

Commit e42af65

Browse files
committed
Add tests for hp contracts
1 parent 337850e commit e42af65

File tree

5 files changed

+99
-3
lines changed

5 files changed

+99
-3
lines changed

tests/data/modules/params_script/hyperparameters.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"dict": {
88
"string": "value",
99
"integer": 3,
10+
"float": 3.14,
1011
"list": [1, 2, 3],
1112
"dict": {"key": "value"},
1213
"boolean": true

tests/data/modules/params_script/hyperparameters.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ list:
99
dict:
1010
string: value
1111
integer: 3
12+
float: 3.14
1213
list:
1314
- 1
1415
- 2
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
omegaconf

tests/data/modules/params_script/train.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import argparse
1717
import json
1818
import os
19+
from typing import List, Dict, Any
20+
from dataclasses import dataclass
21+
from omegaconf import OmegaConf
1922

2023
EXPECTED_HYPERPARAMETERS = {
2124
"integer": 1,
@@ -26,6 +29,7 @@
2629
"dict": {
2730
"string": "value",
2831
"integer": 3,
32+
"float": 3.14,
2933
"list": [1, 2, 3],
3034
"dict": {"key": "value"},
3135
"boolean": True,
@@ -117,7 +121,7 @@ def main():
117121
assert isinstance(params["dict"], dict)
118122

119123
params = json.loads(os.environ["SM_TRAINING_ENV"])["hyperparameters"]
120-
print(params)
124+
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
121125
assert params["string"] == EXPECTED_HYPERPARAMETERS["string"]
122126
assert params["integer"] == EXPECTED_HYPERPARAMETERS["integer"]
123127
assert params["boolean"] == EXPECTED_HYPERPARAMETERS["boolean"]
@@ -132,9 +136,96 @@ def main():
132136
assert isinstance(params["float"], float)
133137
assert isinstance(params["list"], list)
134138
assert isinstance(params["dict"], dict)
135-
print(f"SM_TRAINING_ENV -> hyperparameters: {params}")
136139

137-
print("Test passed.")
140+
# Local JSON - DictConfig OmegaConf
141+
params = OmegaConf.load("hyperparameters.json")
142+
143+
print(f"Local hyperparameters.json: {params}")
144+
assert params.string == EXPECTED_HYPERPARAMETERS["string"]
145+
assert params.integer == EXPECTED_HYPERPARAMETERS["integer"]
146+
assert params.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
147+
assert params.float == EXPECTED_HYPERPARAMETERS["float"]
148+
assert params.list == EXPECTED_HYPERPARAMETERS["list"]
149+
assert params.dict == EXPECTED_HYPERPARAMETERS["dict"]
150+
assert params.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
151+
assert params.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
152+
assert params.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
153+
assert params.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
154+
assert params.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
155+
assert params.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
156+
157+
@dataclass
158+
class DictConfig:
159+
string: str
160+
integer: int
161+
boolean: bool
162+
float: float
163+
list: List[int]
164+
dict: Dict[str, Any]
165+
166+
@dataclass
167+
class HPConfig:
168+
string: str
169+
integer: int
170+
boolean: bool
171+
float: float
172+
list: List[int]
173+
dict: DictConfig
174+
175+
# Local JSON - Structured OmegaConf
176+
hp_config: HPConfig = OmegaConf.merge(
177+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.json")
178+
)
179+
print(f"Local hyperparameters.json - Structured: {hp_config}")
180+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
181+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
182+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
183+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
184+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
185+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
186+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
187+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
188+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
189+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
190+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
191+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
192+
193+
# Local YAML - Structured OmegaConf
194+
hp_config: HPConfig = OmegaConf.merge(
195+
OmegaConf.structured(HPConfig), OmegaConf.load("hyperparameters.yaml")
196+
)
197+
print(f"Local hyperparameters.yaml - Structured: {hp_config}")
198+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
199+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
200+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
201+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
202+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
203+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
204+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
205+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
206+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
207+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
208+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
209+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
210+
print(f"hyperparameters.yaml -> hyperparameters: {hp_config}")
211+
212+
# HP Dict - Structured OmegaConf
213+
hp_dict = json.loads(os.environ["SM_HPS"])
214+
hp_config: HPConfig = OmegaConf.merge(OmegaConf.structured(HPConfig), OmegaConf.create(hp_dict))
215+
print(f"SM_HPS - Structured: {hp_config}")
216+
assert hp_config.string == EXPECTED_HYPERPARAMETERS["string"]
217+
assert hp_config.integer == EXPECTED_HYPERPARAMETERS["integer"]
218+
assert hp_config.boolean == EXPECTED_HYPERPARAMETERS["boolean"]
219+
assert hp_config.float == EXPECTED_HYPERPARAMETERS["float"]
220+
assert hp_config.list == EXPECTED_HYPERPARAMETERS["list"]
221+
assert hp_config.dict == EXPECTED_HYPERPARAMETERS["dict"]
222+
assert hp_config.dict.string == EXPECTED_HYPERPARAMETERS["dict"]["string"]
223+
assert hp_config.dict.integer == EXPECTED_HYPERPARAMETERS["dict"]["integer"]
224+
assert hp_config.dict.boolean == EXPECTED_HYPERPARAMETERS["dict"]["boolean"]
225+
assert hp_config.dict.float == EXPECTED_HYPERPARAMETERS["dict"]["float"]
226+
assert hp_config.dict.list == EXPECTED_HYPERPARAMETERS["dict"]["list"]
227+
assert hp_config.dict.dict == EXPECTED_HYPERPARAMETERS["dict"]["dict"]
228+
print(f"SM_HPS -> hyperparameters: {hp_config}")
138229

139230

140231
if __name__ == "__main__":

tests/integ/sagemaker/modules/train/test_model_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
"dict": {
2929
"string": "value",
3030
"integer": 3,
31+
"float": 3.14,
3132
"list": [1, 2, 3],
3233
"dict": {"key": "value"},
3334
"boolean": True,
@@ -40,6 +41,7 @@
4041
def test_hp_contract_basic_py_script(modules_sagemaker_session):
4142
source_code = SourceCode(
4243
source_dir=f"{DATA_DIR}/modules/params_script",
44+
requirements="requirements.txt",
4345
entry_script="train.py",
4446
)
4547

0 commit comments

Comments
 (0)