Skip to content

Commit 70ebce5

Browse files
Yongyan Raonavinsoni
Yongyan Rao
authored andcommitted
feature: make estimator accept json and yaml file as modelparallel config
1 parent a35a093 commit 70ebce5

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/sagemaker/fw_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import re
@@ -21,6 +22,7 @@
2122
import tempfile
2223
from collections import namedtuple
2324
from typing import Optional, Union, Dict
25+
import yaml
2426

2527
import sagemaker.image_uris
2628
from sagemaker.session_settings import SessionSettings
@@ -234,6 +236,45 @@ def validate_source_code_input_against_pipeline_variables(
234236
)
235237

236238

239+
def parse_mp_parameters(params):
240+
"""Parse the model parallelism parameters provided by the user.
241+
242+
Args:
243+
params: a string representing path to an existing config, or
244+
a config dict.
245+
246+
Returns:
247+
parsed: a dict of parsed config.
248+
249+
Raises:
250+
ValueError: if params is not a string or a dict, or
251+
the config file cannot be parsed as json or yaml.
252+
"""
253+
parsed = None
254+
if isinstance(params, dict):
255+
parsed = params
256+
elif os.path.exists(params):
257+
try:
258+
with open(params, "r") as fp:
259+
parsed = json.load(fp)
260+
except json.decoder.JSONDecodeError:
261+
try:
262+
with open(params, "r") as fp:
263+
parsed = yaml.load(fp)
264+
except yaml.YAMLError:
265+
pass
266+
else:
267+
raise ValueError(
268+
f"Expected a string path to an existing modelparallel config, or a dictionary. "
269+
f"Received: {params}."
270+
)
271+
272+
if parsed is None:
273+
raise ValueError(f"Cannot parse {params} as a json or yaml file.")
274+
275+
return parsed
276+
277+
237278
def get_mp_parameters(distribution):
238279
"""Get the model parallelism parameters provided by the user.
239280
@@ -250,6 +291,7 @@ def get_mp_parameters(distribution):
250291
mp_dict = {}
251292
if mp_dict.get("enabled", False) is True:
252293
params = mp_dict.get("parameters", {})
294+
params = parse_mp_parameters(params)
253295
validate_mp_config(params)
254296
return params
255297
return None

tests/unit/test_fw_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16+
import json
1617
import os
1718
import tarfile
1819
from contextlib import contextmanager
1920
from itertools import product
21+
import yaml
2022

2123
import pytest
2224

@@ -192,6 +194,61 @@ def test_validate_source_dir_file_not_in_dir():
192194
fw_utils.validate_source_dir(script, directory)
193195

194196

197+
def test_parse_mp_parameters_input_dict():
198+
mp_parameters = {
199+
"partitions": 1,
200+
"tensor_parallel_degree": 2,
201+
"microbatches": 1,
202+
"optimize": "speed",
203+
"pipeline": "interleaved",
204+
"ddp": 1,
205+
"auto_partition": False,
206+
"default_partition": 0,
207+
}
208+
assert mp_parameters == fw_utils.parse_mp_parameters(mp_parameters)
209+
210+
211+
def test_parse_mp_parameters_input_str_json():
212+
mp_parameters = {
213+
"partitions": 1,
214+
"tensor_parallel_degree": 2,
215+
"microbatches": 1,
216+
"optimize": "speed",
217+
"pipeline": "interleaved",
218+
"ddp": 1,
219+
"auto_partition": False,
220+
"default_partition": 0,
221+
}
222+
json_file_path = "./params.json"
223+
with open(json_file_path, "x") as fp:
224+
json.dump(mp_parameters, fp)
225+
assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
226+
os.remove(json_file_path)
227+
228+
229+
def test_parse_mp_parameters_input_str_yaml():
230+
mp_parameters = {
231+
"partitions": 1,
232+
"tensor_parallel_degree": 2,
233+
"microbatches": 1,
234+
"optimize": "speed",
235+
"pipeline": "interleaved",
236+
"ddp": 1,
237+
"auto_partition": False,
238+
"default_partition": 0,
239+
}
240+
yaml_file_path = "./params.yaml"
241+
with open(yaml_file_path, "x") as fp:
242+
yaml.dump(mp_parameters, fp)
243+
assert mp_parameters == fw_utils.parse_mp_parameters(yaml_file_path)
244+
os.remove(yaml_file_path)
245+
246+
247+
def test_parse_mp_parameters_input_not_exit():
248+
with pytest.raises(ValueError):
249+
fw_utils.parse_mp_parameters(" !@#$%^&*()path probably in not there.!@#$%^&*()")
250+
251+
195252
def test_tar_and_upload_dir_not_s3(sagemaker_session):
196253
bucket = "mybucket"
197254
s3_key_prefix = "something/source"

0 commit comments

Comments
 (0)