File tree 2 files changed +3
-27
lines changed
2 files changed +3
-27
lines changed Original file line number Diff line number Diff line change 22
22
import tempfile
23
23
from collections import namedtuple
24
24
from typing import Optional , Union , Dict
25
- import yaml
26
25
27
26
import sagemaker .image_uris
28
27
from sagemaker .session_settings import SessionSettings
@@ -248,7 +247,7 @@ def parse_mp_parameters(params):
248
247
249
248
Raises:
250
249
ValueError: if params is not a string or a dict, or
251
- the config file cannot be parsed as json or yaml .
250
+ the config file cannot be parsed as json.
252
251
"""
253
252
parsed = None
254
253
if isinstance (params , dict ):
@@ -258,19 +257,15 @@ def parse_mp_parameters(params):
258
257
with open (params , "r" ) as fp :
259
258
parsed = json .load (fp )
260
259
except json .decoder .JSONDecodeError :
261
- try :
262
- with open (params , "r" ) as fp :
263
- parsed = yaml .load (fp )
264
- except yaml .YAMLError :
265
- pass
260
+ pass
266
261
else :
267
262
raise ValueError (
268
263
f"Expected a string path to an existing modelparallel config, or a dictionary. "
269
264
f"Received: { params } ."
270
265
)
271
266
272
267
if parsed is None :
273
- raise ValueError (f"Cannot parse { params } as a json or yaml file." )
268
+ raise ValueError (f"Cannot parse { params } as a json file." )
274
269
275
270
return parsed
276
271
Original file line number Diff line number Diff line change 18
18
import tarfile
19
19
from contextlib import contextmanager
20
20
from itertools import product
21
- import yaml
22
21
23
22
import pytest
24
23
@@ -226,24 +225,6 @@ def test_parse_mp_parameters_input_str_json():
226
225
os .remove (json_file_path )
227
226
228
227
229
- def test_parse_mp_parameters_input_str_yaml ():
230
- mp_parameters = {
231
- "partitions" : 1 ,
232
- "tensor_parallel_degree" : 2 ,
233
- "microbatches" : 1 ,
234
- "optimize" : "speed" ,
235
- "pipeline" : "interleaved" ,
236
- "ddp" : 1 ,
237
- "auto_partition" : False ,
238
- "default_partition" : 0 ,
239
- }
240
- yaml_file_path = "./params.yaml"
241
- with open (yaml_file_path , "x" ) as fp :
242
- yaml .dump (mp_parameters , fp )
243
- assert mp_parameters == fw_utils .parse_mp_parameters (yaml_file_path )
244
- os .remove (yaml_file_path )
245
-
246
-
247
228
def test_parse_mp_parameters_input_not_exit ():
248
229
with pytest .raises (ValueError ):
249
230
fw_utils .parse_mp_parameters (" !@#$%^&*()path probably in not there.!@#$%^&*()" )
You can’t perform that action at this time.
0 commit comments