16
16
import argparse
17
17
import json
18
18
import os
19
+ from typing import List , Dict , Any
20
+ from dataclasses import dataclass
21
+ from omegaconf import OmegaConf
19
22
20
23
EXPECTED_HYPERPARAMETERS = {
21
24
"integer" : 1 ,
26
29
"dict" : {
27
30
"string" : "value" ,
28
31
"integer" : 3 ,
32
+ "float" : 3.14 ,
29
33
"list" : [1 , 2 , 3 ],
30
34
"dict" : {"key" : "value" },
31
35
"boolean" : True ,
@@ -117,7 +121,7 @@ def main():
117
121
assert isinstance (params ["dict" ], dict )
118
122
119
123
params = json .loads (os .environ ["SM_TRAINING_ENV" ])["hyperparameters" ]
120
- print (params )
124
+ print (f"SM_TRAINING_ENV -> hyperparameters: { params } " )
121
125
assert params ["string" ] == EXPECTED_HYPERPARAMETERS ["string" ]
122
126
assert params ["integer" ] == EXPECTED_HYPERPARAMETERS ["integer" ]
123
127
assert params ["boolean" ] == EXPECTED_HYPERPARAMETERS ["boolean" ]
@@ -132,9 +136,96 @@ def main():
132
136
assert isinstance (params ["float" ], float )
133
137
assert isinstance (params ["list" ], list )
134
138
assert isinstance (params ["dict" ], dict )
135
- print (f"SM_TRAINING_ENV -> hyperparameters: { params } " )
136
139
137
- print ("Test passed." )
140
+ # Local JSON - DictConfig OmegaConf
141
+ params = OmegaConf .load ("hyperparameters.json" )
142
+
143
+ print (f"Local hyperparameters.json: { params } " )
144
+ assert params .string == EXPECTED_HYPERPARAMETERS ["string" ]
145
+ assert params .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
146
+ assert params .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
147
+ assert params .float == EXPECTED_HYPERPARAMETERS ["float" ]
148
+ assert params .list == EXPECTED_HYPERPARAMETERS ["list" ]
149
+ assert params .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
150
+ assert params .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
151
+ assert params .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
152
+ assert params .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
153
+ assert params .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
154
+ assert params .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
155
+ assert params .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
156
+
157
+ @dataclass
158
+ class DictConfig :
159
+ string : str
160
+ integer : int
161
+ boolean : bool
162
+ float : float
163
+ list : List [int ]
164
+ dict : Dict [str , Any ]
165
+
166
+ @dataclass
167
+ class HPConfig :
168
+ string : str
169
+ integer : int
170
+ boolean : bool
171
+ float : float
172
+ list : List [int ]
173
+ dict : DictConfig
174
+
175
+ # Local JSON - Structured OmegaConf
176
+ hp_config : HPConfig = OmegaConf .merge (
177
+ OmegaConf .structured (HPConfig ), OmegaConf .load ("hyperparameters.json" )
178
+ )
179
+ print (f"Local hyperparameters.json - Structured: { hp_config } " )
180
+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
181
+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
182
+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
183
+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
184
+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
185
+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
186
+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
187
+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
188
+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
189
+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
190
+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
191
+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
192
+
193
+ # Local YAML - Structured OmegaConf
194
+ hp_config : HPConfig = OmegaConf .merge (
195
+ OmegaConf .structured (HPConfig ), OmegaConf .load ("hyperparameters.yaml" )
196
+ )
197
+ print (f"Local hyperparameters.yaml - Structured: { hp_config } " )
198
+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
199
+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
200
+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
201
+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
202
+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
203
+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
204
+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
205
+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
206
+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
207
+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
208
+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
209
+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
210
+ print (f"hyperparameters.yaml -> hyperparameters: { hp_config } " )
211
+
212
+ # HP Dict - Structured OmegaConf
213
+ hp_dict = json .loads (os .environ ["SM_HPS" ])
214
+ hp_config : HPConfig = OmegaConf .merge (OmegaConf .structured (HPConfig ), OmegaConf .create (hp_dict ))
215
+ print (f"SM_HPS - Structured: { hp_config } " )
216
+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
217
+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
218
+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
219
+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
220
+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
221
+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
222
+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
223
+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
224
+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
225
+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
226
+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
227
+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
228
+ print (f"SM_HPS -> hyperparameters: { hp_config } " )
138
229
139
230
140
231
if __name__ == "__main__" :
0 commit comments