@@ -1241,8 +1241,9 @@ def test_custom_code_bucket(time, sagemaker_session):
1241
1241
1242
1242
expected_submit_dir = "s3://{}/{}" .format (code_bucket , expected_key )
1243
1243
_ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
1244
- assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == expected_submit_dir
1245
-
1244
+ assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == json .dumps (
1245
+ expected_submit_dir
1246
+ )
1246
1247
1247
1248
@patch ("time.strftime" , return_value = TIMESTAMP )
1248
1249
def test_custom_code_bucket_without_prefix (time , sagemaker_session ):
@@ -1264,7 +1265,9 @@ def test_custom_code_bucket_without_prefix(time, sagemaker_session):
1264
1265
1265
1266
expected_submit_dir = "s3://{}/{}" .format (code_bucket , expected_key )
1266
1267
_ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
1267
- assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == expected_submit_dir
1268
+ assert train_kwargs ["hyperparameters" ]["sagemaker_submit_directory" ] == json .dumps (
1269
+ expected_submit_dir
1270
+ )
1268
1271
1269
1272
1270
1273
def test_invalid_custom_code_bucket (sagemaker_session ):
@@ -1336,9 +1339,11 @@ def test_shuffle_config(sagemaker_session):
1336
1339
1337
1340
1338
1341
BASE_HP = {
1339
- "sagemaker_program" : SCRIPT_NAME ,
1340
- "sagemaker_submit_directory" : "s3://mybucket/{}/source/sourcedir.tar.gz" .format (JOB_NAME ),
1341
- "sagemaker_job_name" : JOB_NAME ,
1342
+ "sagemaker_program" : json .dumps (SCRIPT_NAME ),
1343
+ "sagemaker_submit_directory" : json .dumps (
1344
+ "s3://mybucket/{}/source/sourcedir.tar.gz" .format (JOB_NAME )
1345
+ ),
1346
+ "sagemaker_job_name" : json .dumps (JOB_NAME ),
1342
1347
}
1343
1348
1344
1349
@@ -1383,8 +1388,8 @@ def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session):
1383
1388
t .fit ("s3://{}" .format (uri ))
1384
1389
1385
1390
expected_hyperparameters = BASE_HP .copy ()
1386
- expected_hyperparameters ["sagemaker_container_log_level" ] = logging .INFO
1387
- expected_hyperparameters ["learning_rate" ] = 0.1
1391
+ expected_hyperparameters ["sagemaker_container_log_level" ] = str ( logging .INFO )
1392
+ expected_hyperparameters ["learning_rate" ] = json . dumps ( 0.1 )
1388
1393
expected_hyperparameters ["123" ] = json .dumps ([456 ])
1389
1394
expected_hyperparameters ["sagemaker_region" ] = '"us-west-2"'
1390
1395
@@ -1407,7 +1412,7 @@ def test_start_new_wait_called(strftime, sagemaker_session):
1407
1412
t .fit ("s3://{}" .format (uri ))
1408
1413
1409
1414
expected_hyperparameters = BASE_HP .copy ()
1410
- expected_hyperparameters ["sagemaker_container_log_level" ] = logging .INFO
1415
+ expected_hyperparameters ["sagemaker_container_log_level" ] = str ( logging .INFO )
1411
1416
expected_hyperparameters ["sagemaker_region" ] = '"us-west-2"'
1412
1417
1413
1418
actual_hyperparameter = sagemaker_session .method_calls [1 ][2 ]["hyperparameters" ]
0 commit comments