Skip to content

Commit 8735745

Browse files
yongyanrao (Yongyan Rao)
authored and committed
feature: make estimator accept json file as modelparallel config (aws#3265)
Co-authored-by: Yongyan Rao <[email protected]>
1 parent 7017ce8 commit 8735745

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

src/sagemaker/fw_utils.py

+37
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import re
@@ -234,6 +235,41 @@ def validate_source_code_input_against_pipeline_variables(
234235
)
235236

236237

238+
def parse_mp_parameters(params):
    """Parse the model parallelism parameters provided by the user.

    Args:
        params: a string representing path to an existing config, or
            a config dict.

    Returns:
        parsed: a dict of parsed config. A dict input is returned as-is;
            a path input yields the decoded content of the json file.

    Raises:
        ValueError: if params is not a dict or a path to an existing file, or
            the config file cannot be parsed as json.
    """
    if isinstance(params, dict):
        return params

    # Check the type before touching the filesystem: os.path.exists() raises
    # TypeError for None and interprets an int as an open file descriptor,
    # neither of which matches the documented ValueError contract.
    is_path_like = isinstance(params, (str, bytes, os.PathLike))
    # isfile() (rather than exists()) also rejects directories, which would
    # otherwise fail later inside open() with an OSError.
    if not is_path_like or not os.path.isfile(params):
        raise ValueError(
            f"Expected a string path to an existing modelparallel config, or a dictionary. "
            f"Received: {params}."
        )

    try:
        with open(params, "r", encoding="utf-8") as fp:
            parsed = json.load(fp)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # Keep the decoder error as the cause so the user can see where the
        # file is malformed.
        raise ValueError(f"Cannot parse {params} as a json file.") from err

    # A file containing only the json literal `null` decodes to None, which
    # is not a usable config.
    if parsed is None:
        raise ValueError(f"Cannot parse {params} as a json file.")

    return parsed
271+
272+
237273
def get_mp_parameters(distribution):
238274
"""Get the model parallelism parameters provided by the user.
239275
@@ -250,6 +286,7 @@ def get_mp_parameters(distribution):
250286
mp_dict = {}
251287
if mp_dict.get("enabled", False) is True:
252288
params = mp_dict.get("parameters", {})
289+
params = parse_mp_parameters(params)
253290
validate_mp_config(params)
254291
return params
255292
return None

tests/unit/test_fw_utils.py

+38
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16+
import json
1617
import os
1718
import tarfile
1819
from contextlib import contextmanager
@@ -192,6 +193,43 @@ def test_validate_source_dir_file_not_in_dir():
192193
fw_utils.validate_source_dir(script, directory)
193194

194195

196+
def test_parse_mp_parameters_input_dict():
    """A dict config passed to parse_mp_parameters comes back equal to the input."""
    config = dict(
        partitions=1,
        tensor_parallel_degree=2,
        microbatches=1,
        optimize="speed",
        pipeline="interleaved",
        ddp=1,
        auto_partition=False,
        default_partition=0,
    )
    result = fw_utils.parse_mp_parameters(config)
    assert config == result
208+
209+
210+
def test_parse_mp_parameters_input_str_json(tmp_path):
    """A path to a json config file is loaded and returned as a dict.

    Args:
        tmp_path: pytest's built-in per-test temporary directory fixture.
    """
    mp_parameters = {
        "partitions": 1,
        "tensor_parallel_degree": 2,
        "microbatches": 1,
        "optimize": "speed",
        "pipeline": "interleaved",
        "ddp": 1,
        "auto_partition": False,
        "default_partition": 0,
    }
    # Write into pytest's temporary directory instead of the current working
    # directory: nothing is left behind when the assertion fails, and a stale
    # file from an earlier run can no longer break the test.
    json_file_path = str(tmp_path / "params.json")
    with open(json_file_path, "w") as fp:
        json.dump(mp_parameters, fp)
    assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
226+
227+
228+
def test_parse_mp_parameters_input_not_exit():
    """A string that is not an existing path makes parse_mp_parameters raise ValueError."""
    bogus_path = " !@#$%^&*()path probably in not there.!@#$%^&*()"
    with pytest.raises(ValueError):
        fw_utils.parse_mp_parameters(bogus_path)
231+
232+
195233
def test_tar_and_upload_dir_not_s3(sagemaker_session):
196234
bucket = "mybucket"
197235
s3_key_prefix = "something/source"

0 commit comments

Comments
 (0)