Skip to content

Commit e7d7085

Browse files
author
Christian Osendorfer
committed
Fix for aws#2949.
Hyperparameters are only json encoded at the end of setting up a Sagemaker job.
1 parent 7cd161a commit e7d7085

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

src/sagemaker/estimator.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -2760,11 +2760,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
27602760
init_params = super(Framework, cls)._prepare_init_params_from_job_description(
27612761
job_details, model_channel_name
27622762
)
2763-
2764-
init_params["entry_point"] = init_params["hyperparameters"].get(SCRIPT_PARAM_NAME)
2765-
init_params["source_dir"] = init_params["hyperparameters"].get(DIR_PARAM_NAME)
2766-
init_params["container_log_level"] = init_params["hyperparameters"].get(
2767-
CONTAINER_LOG_LEVEL_PARAM_NAME)
2763+
init_params["entry_point"] = json.loads(
2764+
init_params["hyperparameters"].get(SCRIPT_PARAM_NAME)
2765+
)
2766+
init_params["source_dir"] = json.loads(init_params["hyperparameters"].get(DIR_PARAM_NAME))
2767+
init_params["container_log_level"] = json.loads(
2768+
init_params["hyperparameters"].get(CONTAINER_LOG_LEVEL_PARAM_NAME)
2769+
)
27682770

27692771
hyperparameters = {}
27702772
for k, v in init_params["hyperparameters"].items():

src/sagemaker/local/image.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def _prepare_training_volumes(
493493
# If there is a training script directory and it is a local directory,
494494
# mount it to the container.
495495
if sagemaker.estimator.DIR_PARAM_NAME in hyperparameters:
496-
training_dir = hyperparameters[sagemaker.estimator.DIR_PARAM_NAME]
496+
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
497497
parsed_uri = urlparse(training_dir)
498498
if parsed_uri.scheme == "file":
499499
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
@@ -579,7 +579,7 @@ def _update_local_src_path(self, params, key):
579579
The updated parameters.
580580
"""
581581
if key in params:
582-
src_dir = params[key]
582+
src_dir = json.loads(params[key])
583583
parsed_uri = urlparse(src_dir)
584584
if parsed_uri.scheme == "file":
585585
new_params = params.copy()

tests/unit/test_estimator.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -1241,8 +1241,9 @@ def test_custom_code_bucket(time, sagemaker_session):
12411241

12421242
expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key)
12431243
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
1244-
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir
1245-
1244+
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps(
1245+
expected_submit_dir
1246+
)
12461247

12471248
@patch("time.strftime", return_value=TIMESTAMP)
12481249
def test_custom_code_bucket_without_prefix(time, sagemaker_session):
@@ -1264,7 +1265,9 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session):
12641265

12651266
expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key)
12661267
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
1267-
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir
1268+
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps(
1269+
expected_submit_dir
1270+
)
12681271

12691272

12701273
def test_invalid_custom_code_bucket(sagemaker_session):
@@ -1336,9 +1339,11 @@ def test_shuffle_config(sagemaker_session):
13361339

13371340

13381341
BASE_HP = {
1339-
"sagemaker_program": SCRIPT_NAME,
1340-
"sagemaker_submit_directory": "s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME),
1341-
"sagemaker_job_name": JOB_NAME,
1342+
"sagemaker_program": json.dumps(SCRIPT_NAME),
1343+
"sagemaker_submit_directory": json.dumps(
1344+
"s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME)
1345+
),
1346+
"sagemaker_job_name": json.dumps(JOB_NAME),
13421347
}
13431348

13441349

@@ -1383,8 +1388,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session):
13831388
t.fit("s3://{}".format(uri))
13841389

13851390
expected_hyperparameters = BASE_HP.copy()
1386-
expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO
1387-
expected_hyperparameters["learning_rate"] = 0.1
1391+
expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
1392+
expected_hyperparameters["learning_rate"] = json.dumps(0.1)
13881393
expected_hyperparameters["123"] = json.dumps([456])
13891394
expected_hyperparameters["sagemaker_region"] = '"us-west-2"'
13901395

@@ -1407,7 +1412,7 @@ def test_start_new_wait_called(strftime, sagemaker_session):
14071412
t.fit("s3://{}".format(uri))
14081413

14091414
expected_hyperparameters = BASE_HP.copy()
1410-
expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO
1415+
expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
14111416
expected_hyperparameters["sagemaker_region"] = '"us-west-2"'
14121417

14131418
actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"]

0 commit comments

Comments
 (0)