Skip to content

Commit c121085

Browse files
author
Yongyan Rao
committed
feature: make estimator accept json and yaml file as modelparallel config
1 parent 7d30d8c commit c121085

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/sagemaker/fw_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import re
@@ -21,6 +22,7 @@
2122
import tempfile
2223
from collections import namedtuple
2324
from typing import Optional, Union, Dict
25+
import yaml
2426

2527
import sagemaker.image_uris
2628
from sagemaker.session_settings import SessionSettings
@@ -188,6 +190,45 @@ def validate_source_code_input_against_pipeline_variables(
188190
)
189191

190192

193+
def parse_mp_parameters(params):
194+
"""Parse the model parallelism parameters provided by the user.
195+
196+
Args:
197+
params: a string representing path to an existing config, or
198+
a config dict.
199+
200+
Returns:
201+
parsed: a dict of parsed config.
202+
203+
Raises:
204+
ValueError: if params is not a string or a dict, or
205+
the config file cannot be parsed as json or yaml.
206+
"""
207+
parsed = None
208+
if isinstance(params, dict):
209+
parsed = params
210+
elif os.path.exists(params):
211+
try:
212+
with open(params, "r") as fp:
213+
parsed = json.load(fp)
214+
except json.decoder.JSONDecodeError:
215+
try:
216+
with open(params, "r") as fp:
217+
parsed = yaml.load(fp)
218+
except yaml.YAMLError:
219+
pass
220+
else:
221+
raise ValueError(
222+
f"Expected a string path to an existing modelparallel config, or a dictionary. "
223+
f"Received: {params}."
224+
)
225+
226+
if parsed is None:
227+
raise ValueError(f"Cannot parse {params} as a json or yaml file.")
228+
229+
return parsed
230+
231+
191232
def get_mp_parameters(distribution):
192233
"""Get the model parallelism parameters provided by the user.
193234
@@ -204,6 +245,7 @@ def get_mp_parameters(distribution):
204245
mp_dict = {}
205246
if mp_dict.get("enabled", False) is True:
206247
params = mp_dict.get("parameters", {})
248+
params = parse_mp_parameters(params)
207249
validate_mp_config(params)
208250
return params
209251
return None

tests/unit/test_fw_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16+
import json
1617
import os
1718
import tarfile
1819
from contextlib import contextmanager
1920
from itertools import product
21+
import yaml
2022

2123
import pytest
2224

@@ -201,6 +203,61 @@ def test_validate_source_dir_file_not_in_dir():
201203
fw_utils.validate_source_dir(script, directory)
202204

203205

206+
def test_parse_mp_parameters_input_dict():
207+
mp_parameters = {
208+
"partitions": 1,
209+
"tensor_parallel_degree": 2,
210+
"microbatches": 1,
211+
"optimize": "speed",
212+
"pipeline": "interleaved",
213+
"ddp": 1,
214+
"auto_partition": False,
215+
"default_partition": 0,
216+
}
217+
assert mp_parameters == fw_utils.parse_mp_parameters(mp_parameters)
218+
219+
220+
def test_parse_mp_parameters_input_str_json():
221+
mp_parameters = {
222+
"partitions": 1,
223+
"tensor_parallel_degree": 2,
224+
"microbatches": 1,
225+
"optimize": "speed",
226+
"pipeline": "interleaved",
227+
"ddp": 1,
228+
"auto_partition": False,
229+
"default_partition": 0,
230+
}
231+
json_file_path = "./params.json"
232+
with open(json_file_path, "x") as fp:
233+
json.dump(mp_parameters, fp)
234+
assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
235+
os.remove(json_file_path)
236+
237+
238+
def test_parse_mp_parameters_input_str_yaml():
239+
mp_parameters = {
240+
"partitions": 1,
241+
"tensor_parallel_degree": 2,
242+
"microbatches": 1,
243+
"optimize": "speed",
244+
"pipeline": "interleaved",
245+
"ddp": 1,
246+
"auto_partition": False,
247+
"default_partition": 0,
248+
}
249+
yaml_file_path = "./params.yaml"
250+
with open(yaml_file_path, "x") as fp:
251+
yaml.dump(mp_parameters, fp)
252+
assert mp_parameters == fw_utils.parse_mp_parameters(yaml_file_path)
253+
os.remove(yaml_file_path)
254+
255+
256+
def test_parse_mp_parameters_input_not_exit():
257+
with pytest.raises(ValueError):
258+
fw_utils.parse_mp_parameters(" !@#$%^&*()path probably in not there.!@#$%^&*()")
259+
260+
204261
def test_tar_and_upload_dir_not_s3(sagemaker_session):
205262
bucket = "mybucket"
206263
s3_key_prefix = "something/source"

0 commit comments

Comments
 (0)