Skip to content

Commit dbdf106

Browse files
authored
Update test_estimator.py
Fixing tests to reflect desired behaviour.
1 parent 88ac6c6 commit dbdf106

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

tests/unit/test_estimator.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -1241,9 +1241,7 @@ def test_custom_code_bucket(time, sagemaker_session):
12411241

12421242
expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key)
12431243
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
1244-
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps(
1245-
expected_submit_dir
1246-
)
1244+
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir
12471245

12481246

12491247
@patch("time.strftime", return_value=TIMESTAMP)
@@ -1266,9 +1264,7 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session):
12661264

12671265
expected_submit_dir = "s3://{}/{}".format(code_bucket, expected_key)
12681266
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
1269-
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == json.dumps(
1270-
expected_submit_dir
1271-
)
1267+
assert train_kwargs["hyperparameters"]["sagemaker_submit_directory"] == expected_submit_dir
12721268

12731269

12741270
def test_invalid_custom_code_bucket(sagemaker_session):
@@ -1340,11 +1336,10 @@ def test_shuffle_config(sagemaker_session):
13401336

13411337

13421338
BASE_HP = {
1343-
"sagemaker_program": json.dumps(SCRIPT_NAME),
1344-
"sagemaker_submit_directory": json.dumps(
1345-
"s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME)
1346-
),
1347-
"sagemaker_job_name": json.dumps(JOB_NAME),
1339+
"sagemaker_program": SCRIPT_NAME,
1340+
"sagemaker_submit_directory":
1341+
"s3://mybucket/{}/source/sourcedir.tar.gz".format(JOB_NAME),
1342+
"sagemaker_job_name": JOB_NAME,
13481343
}
13491344

13501345

@@ -1389,8 +1384,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session):
13891384
t.fit("s3://{}".format(uri))
13901385

13911386
expected_hyperparameters = BASE_HP.copy()
1392-
expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
1393-
expected_hyperparameters["learning_rate"] = json.dumps(0.1)
1387+
expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO
1388+
expected_hyperparameters["learning_rate"] = 0.1
13941389
expected_hyperparameters["123"] = json.dumps([456])
13951390
expected_hyperparameters["sagemaker_region"] = '"us-west-2"'
13961391

@@ -1413,7 +1408,7 @@ def test_start_new_wait_called(strftime, sagemaker_session):
14131408
t.fit("s3://{}".format(uri))
14141409

14151410
expected_hyperparameters = BASE_HP.copy()
1416-
expected_hyperparameters["sagemaker_container_log_level"] = str(logging.INFO)
1411+
expected_hyperparameters["sagemaker_container_log_level"] = logging.INFO
14171412
expected_hyperparameters["sagemaker_region"] = '"us-west-2"'
14181413

14191414
actual_hyperparameter = sagemaker_session.method_calls[1][2]["hyperparameters"]

0 commit comments

Comments
 (0)