File tree 2 files changed +3
-27
lines changed
2 files changed +3
-27
lines changed Original file line number Diff line number Diff line change 22
22
import tempfile
23
23
from collections import namedtuple
24
24
from typing import Optional , Union , Dict
25
- import yaml
26
25
27
26
import sagemaker .image_uris
28
27
from sagemaker .session_settings import SessionSettings
@@ -222,7 +221,7 @@ def parse_mp_parameters(params):
222
221
223
222
Raises:
224
223
ValueError: if params is not a string or a dict, or
225
- the config file cannot be parsed as json or yaml .
224
+ the config file cannot be parsed as json.
226
225
"""
227
226
parsed = None
228
227
if isinstance (params , dict ):
@@ -232,19 +231,15 @@ def parse_mp_parameters(params):
232
231
with open (params , "r" ) as fp :
233
232
parsed = json .load (fp )
234
233
except json .decoder .JSONDecodeError :
235
- try :
236
- with open (params , "r" ) as fp :
237
- parsed = yaml .load (fp )
238
- except yaml .YAMLError :
239
- pass
234
+ pass
240
235
else :
241
236
raise ValueError (
242
237
f"Expected a string path to an existing modelparallel config, or a dictionary. "
243
238
f"Received: { params } ."
244
239
)
245
240
246
241
if parsed is None :
247
- raise ValueError (f"Cannot parse { params } as a json or yaml file." )
242
+ raise ValueError (f"Cannot parse { params } as a json file." )
248
243
249
244
return parsed
250
245
Original file line number Diff line number Diff line change 18
18
import tarfile
19
19
from contextlib import contextmanager
20
20
from itertools import product
21
- import yaml
22
21
23
22
import pytest
24
23
@@ -226,24 +225,6 @@ def test_parse_mp_parameters_input_str_json():
226
225
os .remove (json_file_path )
227
226
228
227
229
- def test_parse_mp_parameters_input_str_yaml ():
230
- mp_parameters = {
231
- "partitions" : 1 ,
232
- "tensor_parallel_degree" : 2 ,
233
- "microbatches" : 1 ,
234
- "optimize" : "speed" ,
235
- "pipeline" : "interleaved" ,
236
- "ddp" : 1 ,
237
- "auto_partition" : False ,
238
- "default_partition" : 0 ,
239
- }
240
- yaml_file_path = "./params.yaml"
241
- with open (yaml_file_path , "x" ) as fp :
242
- yaml .dump (mp_parameters , fp )
243
- assert mp_parameters == fw_utils .parse_mp_parameters (yaml_file_path )
244
- os .remove (yaml_file_path )
245
-
246
-
247
228
def test_parse_mp_parameters_input_not_exit ():
248
229
with pytest .raises (ValueError ):
249
230
fw_utils .parse_mp_parameters (" !@#$%^&*()path probably in not there.!@#$%^&*()" )
You can’t perform that action at this time.
0 commit comments