File tree 2 files changed +3
-27
lines changed
2 files changed +3
-27
lines changed Original file line number Diff line number Diff line change 22
22
import tempfile
23
23
from collections import namedtuple
24
24
from typing import Optional , Union , Dict
25
- import yaml
26
25
27
26
import sagemaker .image_uris
28
27
from sagemaker .session_settings import SessionSettings
@@ -202,7 +201,7 @@ def parse_mp_parameters(params):
202
201
203
202
Raises:
204
203
ValueError: if params is not a string or a dict, or
205
- the config file cannot be parsed as json or yaml .
204
+ the config file cannot be parsed as json.
206
205
"""
207
206
parsed = None
208
207
if isinstance (params , dict ):
@@ -212,19 +211,15 @@ def parse_mp_parameters(params):
212
211
with open (params , "r" ) as fp :
213
212
parsed = json .load (fp )
214
213
except json .decoder .JSONDecodeError :
215
- try :
216
- with open (params , "r" ) as fp :
217
- parsed = yaml .load (fp )
218
- except yaml .YAMLError :
219
- pass
214
+ pass
220
215
else :
221
216
raise ValueError (
222
217
f"Expected a string path to an existing modelparallel config, or a dictionary. "
223
218
f"Received: { params } ."
224
219
)
225
220
226
221
if parsed is None :
227
- raise ValueError (f"Cannot parse { params } as a json or yaml file." )
222
+ raise ValueError (f"Cannot parse { params } as a json file." )
228
223
229
224
return parsed
230
225
Original file line number Diff line number Diff line change 18
18
import tarfile
19
19
from contextlib import contextmanager
20
20
from itertools import product
21
- import yaml
22
21
23
22
import pytest
24
23
@@ -235,24 +234,6 @@ def test_parse_mp_parameters_input_str_json():
235
234
os .remove (json_file_path )
236
235
237
236
238
- def test_parse_mp_parameters_input_str_yaml ():
239
- mp_parameters = {
240
- "partitions" : 1 ,
241
- "tensor_parallel_degree" : 2 ,
242
- "microbatches" : 1 ,
243
- "optimize" : "speed" ,
244
- "pipeline" : "interleaved" ,
245
- "ddp" : 1 ,
246
- "auto_partition" : False ,
247
- "default_partition" : 0 ,
248
- }
249
- yaml_file_path = "./params.yaml"
250
- with open (yaml_file_path , "x" ) as fp :
251
- yaml .dump (mp_parameters , fp )
252
- assert mp_parameters == fw_utils .parse_mp_parameters (yaml_file_path )
253
- os .remove (yaml_file_path )
254
-
255
-
256
237
def test_parse_mp_parameters_input_not_exit ():
257
238
with pytest .raises (ValueError ):
258
239
fw_utils .parse_mp_parameters (" !@#$%^&*()path probably in not there.!@#$%^&*()" )
You can’t perform that action at this time.
0 commit comments