@@ -1241,9 +1241,7 @@ def test_custom_code_bucket(time, sagemaker_session):
1241
1241
1242
1242
expected_submit_dir = "s3://{}/{}" .format (code_bucket , expected_key )
1243
1243
_ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
1244
- assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == json .dumps (
1245
- expected_submit_dir
1246
- )
1244
+ assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == expected_submit_dir
1247
1245
1248
1246
1249
1247
@patch ("time.strftime" , return_value = TIMESTAMP )
@@ -1266,9 +1264,7 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session):
1266
1264
1267
1265
expected_submit_dir = "s3://{}/{}" .format (code_bucket , expected_key )
1268
1266
_ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
1269
- assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == json .dumps (
1270
- expected_submit_dir
1271
- )
1267
+ assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == expected_submit_dir
1272
1268
1273
1269
1274
1270
def test_invalid_custom_code_bucket (sagemaker_session ):
@@ -1340,11 +1336,10 @@ def test_shuffle_config(sagemaker_session):
1340
1336
1341
1337
1342
1338
BASE_HP = {
1343
- "sagemaker_program" : json .dumps (SCRIPT_NAME ),
1344
- "sagemaker_submit_directory" : json .dumps (
1345
- "s3://mybucket/{}/source/sourcedir.tar.gz" .format (JOB_NAME )
1346
- ),
1347
- "sagemaker_job_name" : json .dumps (JOB_NAME ),
1339
+ "sagemaker_program" : SCRIPT_NAME ,
1340
+ "sagemaker_submit_directory" :
1341
+ "s3://mybucket/{}/source/sourcedir.tar.gz" .format (JOB_NAME ),
1342
+ "sagemaker_job_name" : JOB_NAME ,
1348
1343
}
1349
1344
1350
1345
@@ -1389,8 +1384,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session):
1389
1384
t .fit ("s3://{}" .format (uri ))
1390
1385
1391
1386
expected_hyperparameters = BASE_HP .copy ()
1392
- expected_hyperparameters ["sagemaker_container_log_level" ] = str ( logging .INFO )
1393
- expected_hyperparameters ["learning_rate" ] = json . dumps ( 0.1 )
1387
+ expected_hyperparameters ["sagemaker_container_log_level" ] = logging .INFO
1388
+ expected_hyperparameters ["learning_rate" ] = 0.1
1394
1389
expected_hyperparameters ["123" ] = json .dumps ([456 ])
1395
1390
expected_hyperparameters ["sagemaker_region" ] = '"us-west-2"'
1396
1391
@@ -1413,7 +1408,7 @@ def test_start_new_wait_called(strftime, sagemaker_session):
1413
1408
t .fit ("s3://{}" .format (uri ))
1414
1409
1415
1410
expected_hyperparameters = BASE_HP .copy ()
1416
- expected_hyperparameters ["sagemaker_container_log_level" ] = str ( logging .INFO )
1411
+ expected_hyperparameters ["sagemaker_container_log_level" ] = logging .INFO
1417
1412
expected_hyperparameters ["sagemaker_region" ] = '"us-west-2"'
1418
1413
1419
1414
actual_hyperparameter = sagemaker_session .method_calls [1 ][2 ]["hyperparameters" ]
0 commit comments