Skip to content

Commit ae7a7e9

Browse files
committed
fix: Update HPO TrainingInputMode with s3_input InputMode
1 parent 6fad54b commit ae7a7e9

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,8 @@ def start_new(cls, estimator, inputs):
571571

572572
if isinstance(inputs, s3_input):
573573
if 'InputMode' in inputs.config:
574+
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
575+
.format(inputs.config['InputMode']))
574576
train_args['input_mode'] = inputs.config['InputMode']
575577

576578
if estimator.enable_network_isolation():

src/sagemaker/tuner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import importlib
1616
import inspect
1717
import json
18+
import logging
1819
from enum import Enum
1920

2021
import sagemaker
@@ -26,6 +27,7 @@
2627
from sagemaker.parameter import (CategoricalParameter, ContinuousParameter,
2728
IntegerParameter, ParameterRange)
2829
from sagemaker.session import Session
30+
from sagemaker.session import s3_input
2931
from sagemaker.utils import base_name_from_image, name_from_base, to_str
3032

3133
AMAZON_ESTIMATOR_MODULE = 'sagemaker'
@@ -640,6 +642,12 @@ def start_new(cls, tuner, inputs):
640642
tuner_args['warm_start_config'] = warm_start_config_req
641643
tuner_args['early_stopping_type'] = tuner.early_stopping_type
642644

645+
if isinstance(inputs, s3_input):
646+
if 'InputMode' in inputs.config:
647+
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
648+
.format(inputs.config['InputMode']))
649+
tuner_args['input_mode'] = inputs.config['InputMode']
650+
643651
if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator):
644652
tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn
645653
else:

tests/unit/test_tuner.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner,
2929
create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig,
3030
WarmStartTypes)
31+
from sagemaker.session import s3_input
3132

3233
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
3334
MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -286,6 +287,31 @@ def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner):
286287
assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids}
287288

288289

290+
def test_s3_input_mode(sagemaker_session, tuner):
291+
expected_input_mode = 'Pipe'
292+
293+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
294+
mxnet = MXNet(entry_point=script_path,
295+
role=ROLE,
296+
framework_version=FRAMEWORK_VERSION,
297+
train_instance_count=TRAIN_INSTANCE_COUNT,
298+
train_instance_type=TRAIN_INSTANCE_TYPE,
299+
sagemaker_session=sagemaker_session)
300+
tuner.estimator = mxnet
301+
302+
tags = [{'Name': 'some-tag-without-a-value'}]
303+
tuner.tags = tags
304+
305+
hyperparameter_ranges = {'num_components': IntegerParameter(2, 4),
306+
'algorithm_mode': CategoricalParameter(['regular', 'randomized'])}
307+
tuner._hyperparameter_ranges = hyperparameter_ranges
308+
309+
tuner.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))
310+
311+
actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
312+
assert actual_input_mode == expected_input_mode
313+
314+
289315
def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner):
290316
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
291317
base_job_name='pca', sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)