@@ -229,6 +229,39 @@ def test_train_pack_to_request(sagemaker_session):
229
229
'create_training_job' , (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS )
230
230
231
231
232
+ def test_train_pack_to_request_with_optional_params (sagemaker_session ):
233
+ in_config = [{
234
+ 'ChannelName' : 'training' ,
235
+ 'DataSource' : {
236
+ 'S3DataSource' : {
237
+ 'S3DataDistributionType' : 'FullyReplicated' ,
238
+ 'S3DataType' : 'S3Prefix' ,
239
+ 'S3Uri' : S3_INPUT_URI
240
+ }
241
+ }
242
+ }]
243
+
244
+ out_config = {'S3OutputPath' : S3_OUTPUT }
245
+
246
+ resource_config = {'InstanceCount' : INSTANCE_COUNT ,
247
+ 'InstanceType' : INSTANCE_TYPE ,
248
+ 'VolumeSizeInGB' : MAX_SIZE }
249
+
250
+ stop_cond = {'MaxRuntimeInSeconds' : MAX_TIME }
251
+
252
+ hyperparameters = {'foo' : 'bar' }
253
+ tags = [{'Name' : 'some-tag' , 'Value' : 'value-for-tag' }]
254
+
255
+ sagemaker_session .train (image = IMAGE , input_mode = 'File' , input_config = in_config , role = EXPANDED_ROLE ,
256
+ job_name = JOB_NAME , output_config = out_config , resource_config = resource_config ,
257
+ hyperparameters = hyperparameters , stop_condition = stop_cond , tags = tags )
258
+
259
+ _ , _ , actual_train_args = sagemaker_session .sagemaker_client .method_calls [0 ]
260
+
261
+ assert actual_train_args ['HyperParameters' ] == hyperparameters
262
+ assert actual_train_args ['Tags' ] == tags
263
+
264
+
232
265
@patch ('sys.stdout' , new_callable = io .BytesIO if six .PY2 else io .StringIO )
233
266
def test_color_wrap (bio ):
234
267
color_wrap = sagemaker .logs .ColorWrap ()
0 commit comments