11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ from copy import deepcopy
14
15
15
16
import logging
16
17
import json
@@ -3825,6 +3826,12 @@ def test_script_mode_estimator_same_calls_as_framework(
3825
3826
3826
3827
model_uri = "s3://someprefix2/models/model.tar.gz"
3827
3828
training_data_uri = "s3://bucket/mydata"
3829
+ hyperparameters = {
3830
+ "int_hyperparam" : 1 ,
3831
+ "string_hyperparam" : "hello" ,
3832
+ "stringified_numeric_hyperparam" : "44" ,
3833
+ "float_hyperparam" : 1.234 ,
3834
+ }
3828
3835
3829
3836
generic_estimator = Estimator (
3830
3837
entry_point = SCRIPT_PATH ,
@@ -3838,6 +3845,7 @@ def test_script_mode_estimator_same_calls_as_framework(
3838
3845
model_uri = model_uri ,
3839
3846
dependencies = [],
3840
3847
debugger_hook_config = {},
3848
+ hyperparameters = deepcopy (hyperparameters ),
3841
3849
)
3842
3850
generic_estimator .fit (training_data_uri )
3843
3851
@@ -3858,6 +3866,7 @@ def test_script_mode_estimator_same_calls_as_framework(
3858
3866
model_uri = model_uri ,
3859
3867
dependencies = [],
3860
3868
debugger_hook_config = {},
3869
+ hyperparameters = deepcopy (hyperparameters ),
3861
3870
)
3862
3871
framework_estimator .fit (training_data_uri )
3863
3872
@@ -4394,3 +4403,51 @@ def test_insert_invalid_source_code_args():
4394
4403
assert (
4395
4404
"The entry_point should not be a pipeline variable " "when source_dir is a local path"
4396
4405
) in str (err .value )
4406
+
4407
+
4408
+ @patch ("time.time" , return_value = TIME )
4409
+ @patch ("sagemaker.estimator.tar_and_upload_dir" )
4410
+ @patch ("sagemaker.model.Model._upload_code" )
4411
+ def test_script_mode_estimator_escapes_hyperparameters_as_json (
4412
+ patched_upload_code , patched_tar_and_upload_dir , sagemaker_session
4413
+ ):
4414
+ patched_tar_and_upload_dir .return_value = UploadedCode (
4415
+ s3_prefix = "s3://%s/%s" % ("bucket" , "key" ), script_name = "script_name"
4416
+ )
4417
+ sagemaker_session .boto_region_name = REGION
4418
+
4419
+ instance_type = "ml.p2.xlarge"
4420
+ instance_count = 1
4421
+
4422
+ training_data_uri = "s3://bucket/mydata"
4423
+
4424
+ jumpstart_source_dir = f"s3://{ list (JUMPSTART_BUCKET_NAME_SET )[0 ]} /source_dirs/source.tar.gz"
4425
+
4426
+ hyperparameters = {
4427
+ "int_hyperparam" : 1 ,
4428
+ "string_hyperparam" : "hello" ,
4429
+ "stringified_numeric_hyperparam" : "44" ,
4430
+ "float_hyperparam" : 1.234 ,
4431
+ }
4432
+
4433
+ generic_estimator = Estimator (
4434
+ entry_point = SCRIPT_PATH ,
4435
+ role = ROLE ,
4436
+ region = REGION ,
4437
+ sagemaker_session = sagemaker_session ,
4438
+ instance_count = instance_count ,
4439
+ instance_type = instance_type ,
4440
+ source_dir = jumpstart_source_dir ,
4441
+ image_uri = IMAGE_URI ,
4442
+ model_uri = MODEL_DATA ,
4443
+ hyperparameters = hyperparameters ,
4444
+ )
4445
+ generic_estimator .fit (training_data_uri )
4446
+
4447
+ formatted_hyperparams = EstimatorBase ._json_encode_hyperparameters (hyperparameters )
4448
+
4449
+ assert (
4450
+ set (formatted_hyperparams .items ())
4451
+ - set (sagemaker_session .train .call_args_list [0 ][1 ]["hyperparameters" ].items ())
4452
+ == set ()
4453
+ )
0 commit comments