Skip to content

Commit 369fb6d

Browse files
committed
Add unit test for s3_input InputMode
1 parent 747e4cc commit 369fb6d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def start_new(cls, estimator, inputs):
570570
train_args['metric_definitions'] = estimator.metric_definitions
571571

572572
if isinstance(inputs, s3_input):
573-
if inputs.config['InputMode'] is not None:
573+
if 'InputMode' in inputs.config:
574574
train_args['input_mode'] = inputs.config['InputMode']
575575

576576
if estimator.enable_network_isolation():

tests/unit/test_estimator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,17 @@ def test_augmented_manifest(sagemaker_session):
315315
assert s3_data_source['AttributeNames'] == ['foo', 'bar']
316316

317317

318+
def test_s3_input_mode(sagemaker_session):
319+
expected_input_mode = 'Pipe'
320+
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
321+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
322+
enable_cloudwatch_metrics=True)
323+
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))
324+
325+
actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
326+
assert actual_input_mode == expected_input_mode
327+
328+
318329
def test_shuffle_config(sagemaker_session):
319330
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
320331
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,

0 commit comments

Comments
 (0)