Skip to content

Commit 09b7deb

Browse files
Yongyan Rao, claytonparnell
Yongyan Rao
authored and committed
feature: make estimator accept json and yaml file as modelparallel config
1 parent f2d5e41 commit 09b7deb

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/sagemaker/fw_utils.py

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Utility methods used by framework classes"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import re
@@ -21,6 +22,7 @@
2122
import tempfile
2223
from collections import namedtuple
2324
from typing import Optional, Union, Dict
25+
import yaml
2426

2527
import sagemaker.image_uris
2628
from sagemaker.session_settings import SessionSettings
@@ -208,6 +210,45 @@ def validate_source_code_input_against_pipeline_variables(
208210
)
209211

210212

213+
def parse_mp_parameters(params):
    """Parse the model parallelism parameters provided by the user.

    Args:
        params: a config dict, or a path (``str``, ``bytes`` or
            ``os.PathLike``) to an existing json or yaml config file.

    Returns:
        parsed: a dict of parsed config. A dict input is returned as-is
            (the same object, not a copy).

    Raises:
        ValueError: if params is neither a dict nor a path to an existing
            file, if the config file cannot be parsed as json or yaml, or
            if the parsed content is not a dict.
    """
    if isinstance(params, dict):
        return params

    # Check the type before touching the filesystem: os.path functions raise
    # TypeError for arbitrary objects and treat integers as file descriptors.
    # isfile (rather than exists) also rejects directories up front.
    if not isinstance(params, (str, bytes, os.PathLike)) or not os.path.isfile(params):
        raise ValueError(
            f"Expected a string path to an existing modelparallel config, or a dictionary. "
            f"Received: {params}."
        )

    # Read once and try both parsers on the same text instead of reopening.
    with open(params, "r") as fp:
        content = fp.read()

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        try:
            # safe_load: the config is user-supplied, so never construct
            # arbitrary Python objects; a bare yaml.load() without a Loader
            # is also rejected by PyYAML >= 6.
            parsed = yaml.safe_load(content)
        except yaml.YAMLError as err:
            raise ValueError(f"Cannot parse {params} as a json or yaml file.") from err

    # Empty files parse to None and plain text parses to a yaml scalar;
    # neither is a usable config.
    if not isinstance(parsed, dict):
        raise ValueError(f"Cannot parse {params} as a json or yaml file.")

    return parsed
250+
251+
211252
def get_mp_parameters(distribution):
212253
"""Get the model parallelism parameters provided by the user.
213254
@@ -224,6 +265,7 @@ def get_mp_parameters(distribution):
224265
mp_dict = {}
225266
if mp_dict.get("enabled", False) is True:
226267
params = mp_dict.get("parameters", {})
268+
params = parse_mp_parameters(params)
227269
validate_mp_config(params)
228270
return params
229271
return None

tests/unit/test_fw_utils.py

+57
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from __future__ import absolute_import
1414

1515
import inspect
16+
import json
1617
import os
1718
import tarfile
1819
from contextlib import contextmanager
1920
from itertools import product
21+
import yaml
2022

2123
import pytest
2224

@@ -192,6 +194,61 @@ def test_validate_source_dir_file_not_in_dir():
192194
fw_utils.validate_source_dir(script, directory)
193195

194196

197+
def test_parse_mp_parameters_input_dict():
    """A dict config is handed back unchanged by parse_mp_parameters."""
    config = dict(
        [
            ("partitions", 1),
            ("tensor_parallel_degree", 2),
            ("microbatches", 1),
            ("optimize", "speed"),
            ("pipeline", "interleaved"),
            ("ddp", 1),
            ("auto_partition", False),
            ("default_partition", 0),
        ]
    )
    result = fw_utils.parse_mp_parameters(config)
    assert config == result
209+
210+
211+
def test_parse_mp_parameters_input_str_json(tmp_path):
    """A path to a json config file is parsed into the equivalent dict."""
    mp_parameters = {
        "partitions": 1,
        "tensor_parallel_degree": 2,
        "microbatches": 1,
        "optimize": "speed",
        "pipeline": "interleaved",
        "ddp": 1,
        "auto_partition": False,
        "default_partition": 0,
    }
    # Use pytest's per-test temporary directory so the test neither collides
    # with an existing params.json in the working directory nor leaves one
    # behind when the assertion fails.
    json_file_path = str(tmp_path / "params.json")
    with open(json_file_path, "w") as fp:
        json.dump(mp_parameters, fp)
    assert mp_parameters == fw_utils.parse_mp_parameters(json_file_path)
227+
228+
229+
def test_parse_mp_parameters_input_str_yaml(tmp_path):
    """A path to a yaml config file is parsed into the equivalent dict."""
    mp_parameters = {
        "partitions": 1,
        "tensor_parallel_degree": 2,
        "microbatches": 1,
        "optimize": "speed",
        "pipeline": "interleaved",
        "ddp": 1,
        "auto_partition": False,
        "default_partition": 0,
    }
    # Use pytest's per-test temporary directory so the test neither collides
    # with an existing params.yaml in the working directory nor leaves one
    # behind when the assertion fails.
    yaml_file_path = str(tmp_path / "params.yaml")
    with open(yaml_file_path, "w") as fp:
        yaml.dump(mp_parameters, fp)
    assert mp_parameters == fw_utils.parse_mp_parameters(yaml_file_path)
245+
246+
247+
def test_parse_mp_parameters_input_not_exit():
    """A string that is not an existing path is rejected with ValueError."""
    missing_path = " !@#$%^&*()path probably in not there.!@#$%^&*()"
    with pytest.raises(ValueError):
        fw_utils.parse_mp_parameters(missing_path)
250+
251+
195252
def test_tar_and_upload_dir_not_s3(sagemaker_session):
196253
bucket = "mybucket"
197254
s3_key_prefix = "something/source"

0 commit comments

Comments
 (0)