Skip to content

Commit a741609

Browse files
author
Yongyan Rao
committed
feature: make estimator accept json and yaml file as modelparallel config
1 parent 011539a commit a741609

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/sagemaker/fw_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import re
@@ -21,6 +22,7 @@
2122
import tempfile
2223
from collections import namedtuple
2324
from typing import Optional
25+
import yaml
2426

2527
import sagemaker.image_uris
2628
from sagemaker.session_settings import SessionSettings
@@ -124,6 +126,45 @@ def validate_source_dir(script, directory):
124126
return True
125127

126128

129+
def parse_mp_parameters(params):
130+
"""Parse the model parallelism parameters provided by the user.
131+
132+
Args:
133+
params: a string representing path to an existing config, or
134+
a config dict.
135+
136+
Returns:
137+
parsed: a dict of parsed config.
138+
139+
Raises:
140+
ValueError: if params is not a string or a dict, or
141+
the config file cannot be parsed as json or yaml.
142+
"""
143+
parsed = None
144+
if isinstance(params, dict):
145+
parsed = params
146+
elif os.path.exists(params):
147+
try:
148+
with open(params, "r") as fp:
149+
parsed = json.load(fp)
150+
except json.decoder.JSONDecodeError:
151+
try:
152+
with open(params, "r") as fp:
153+
parsed = yaml.load(fp)
154+
except yaml.YAMLError:
155+
pass
156+
else:
157+
raise ValueError(
158+
f"Expected a string path to an existing modelparallel config, or a dictionary. "
159+
f"Received: {params}."
160+
)
161+
162+
if parsed is None:
163+
raise ValueError(f"Cannot parse {params} as a json or yaml file.")
164+
165+
return parsed
166+
167+
127168
def get_mp_parameters(distribution):
128169
"""Get the model parallelism parameters provided by the user.
129170
@@ -140,6 +181,7 @@ def get_mp_parameters(distribution):
140181
mp_dict = {}
141182
if mp_dict.get("enabled", False) is True:
142183
params = mp_dict.get("parameters", {})
184+
params = parse_mp_parameters(params)
143185
validate_mp_config(params)
144186
return params
145187
return None

tests/unit/test_fw_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16+
import json
1617
import os
1718
import tarfile
1819
from contextlib import contextmanager
1920
from itertools import product
21+
import yaml
2022

2123
import pytest
2224

@@ -192,6 +194,61 @@ def test_validate_source_dir_file_not_in_dir():
192194
fw_utils.validate_source_dir(script, directory)
193195

194196

197+
def test_parse_mp_parameters_input_dict():
198+
mp_parameters = {
199+
"partitions": 1,
200+
"tensor_parallel_degree": 2,
201+
"microbatches": 1,
202+
"optimize": "speed",
203+
"pipeline": "interleaved",
204+
"ddp": 1,
205+
"auto_partition": False,
206+
"default_partition": 0,
207+
}
208+
assert mp_parameters == fw_utils.parse_mp_parameters(mp_parameters)
209+
210+
211+
def test_parse_mp_parameters_input_str_json():
212+
mp_parameters = {
213+
"partitions": 1,
214+
"tensor_parallel_degree": 2,
215+
"microbatches": 1,
216+
"optimize": "speed",
217+
"pipeline": "interleaved",
218+
"ddp": 1,
219+
"auto_partition": False,
220+
"default_partition": 0,
221+
}
222+
json_file_path = "./params.json"
223+
with open(json_file_path, "x") as fp:
224+
json.dump(mp_parameters, fp)
225+
assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
226+
os.remove(json_file_path)
227+
228+
229+
def test_parse_mp_parameters_input_str_yaml():
230+
mp_parameters = {
231+
"partitions": 1,
232+
"tensor_parallel_degree": 2,
233+
"microbatches": 1,
234+
"optimize": "speed",
235+
"pipeline": "interleaved",
236+
"ddp": 1,
237+
"auto_partition": False,
238+
"default_partition": 0,
239+
}
240+
yaml_file_path = "./params.yaml"
241+
with open(yaml_file_path, "x") as fp:
242+
yaml.dump(mp_parameters, fp)
243+
assert mp_parameters == fw_utils.parse_mp_parameters(yaml_file_path)
244+
os.remove(yaml_file_path)
245+
246+
247+
def test_parse_mp_parameters_input_not_exit():
248+
with pytest.raises(ValueError):
249+
fw_utils.parse_mp_parameters(" !@#$%^&*()path probably in not there.!@#$%^&*()")
250+
251+
195252
def test_tar_and_upload_dir_not_s3(sagemaker_session):
196253
bucket = "mybucket"
197254
s3_key_prefix = "something/source"

0 commit comments

Comments
 (0)