Skip to content

Commit 5418d5a

Browse files
committed
Add another unit test
This should increase code coverage by testing if the optional training parameters are set by the user.
1 parent 6cace84 commit 5418d5a

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/unit/test_session.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,39 @@ def test_train_pack_to_request(sagemaker_session):
229229
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
230230

231231

232+
def test_train_pack_to_request_with_optional_params(sagemaker_session):
233+
in_config = [{
234+
'ChannelName': 'training',
235+
'DataSource': {
236+
'S3DataSource': {
237+
'S3DataDistributionType': 'FullyReplicated',
238+
'S3DataType': 'S3Prefix',
239+
'S3Uri': S3_INPUT_URI
240+
}
241+
}
242+
}]
243+
244+
out_config = {'S3OutputPath': S3_OUTPUT}
245+
246+
resource_config = {'InstanceCount': INSTANCE_COUNT,
247+
'InstanceType': INSTANCE_TYPE,
248+
'VolumeSizeInGB': MAX_SIZE}
249+
250+
stop_cond = {'MaxRuntimeInSeconds': MAX_TIME}
251+
252+
hyperparameters = {'foo': 'bar'}
253+
tags = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
254+
255+
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
256+
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
257+
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=tags)
258+
259+
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
260+
261+
assert actual_train_args['HyperParameters'] == hyperparameters
262+
assert actual_train_args['Tags'] == tags
263+
264+
232265
@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
233266
def test_color_wrap(bio):
234267
color_wrap = sagemaker.logs.ColorWrap()

0 commit comments

Comments
 (0)