Skip to content

Commit 7529a22

Browse files
imujjwal96laurenyu
authored andcommitted
fix: update TrainingInputMode with s3_input InputMode (#776)
1 parent 555b8b1 commit 7529a22

File tree

5 files changed

+53
-2
lines changed

5 files changed

+53
-2
lines changed

src/sagemaker/estimator.py

+6
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,12 @@ def start_new(cls, estimator, inputs):
569569
train_args['tags'] = estimator.tags
570570
train_args['metric_definitions'] = estimator.metric_definitions
571571

572+
if isinstance(inputs, s3_input):
573+
if 'InputMode' in inputs.config:
574+
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
575+
.format(inputs.config['InputMode']))
576+
train_args['input_mode'] = inputs.config['InputMode']
577+
572578
if estimator.enable_network_isolation():
573579
train_args['enable_network_isolation'] = True
574580

src/sagemaker/tuner.py

+8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import importlib
1616
import inspect
1717
import json
18+
import logging
1819
from enum import Enum
1920

2021
import sagemaker
@@ -26,6 +27,7 @@
2627
from sagemaker.parameter import (CategoricalParameter, ContinuousParameter,
2728
IntegerParameter, ParameterRange)
2829
from sagemaker.session import Session
30+
from sagemaker.session import s3_input
2931
from sagemaker.utils import base_name_from_image, name_from_base, to_str
3032

3133
AMAZON_ESTIMATOR_MODULE = 'sagemaker'
@@ -640,6 +642,12 @@ def start_new(cls, tuner, inputs):
640642
tuner_args['warm_start_config'] = warm_start_config_req
641643
tuner_args['early_stopping_type'] = tuner.early_stopping_type
642644

645+
if isinstance(inputs, s3_input):
646+
if 'InputMode' in inputs.config:
647+
logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.'
648+
.format(inputs.config['InputMode']))
649+
tuner_args['input_mode'] = inputs.config['InputMode']
650+
643651
if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator):
644652
tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn
645653
else:

tests/scripts/run-notebook-test.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ aws s3 --region us-west-2 cp ./dist/sagemaker-*.tar.gz s3://sagemaker-python-sdk
1010
aws s3 cp s3://sagemaker-mead-cli/mead-nb-test.tar.gz mead-nb-test.tar.gz
1111
tar -xzf mead-nb-test.tar.gz
1212
git clone --depth 1 https://github.com/awslabs/amazon-sagemaker-examples.git
13-
JAVA_HOME=$(get-java-home)
13+
export JAVA_HOME=$(get-java-home)
1414
echo "set JAVA_HOME=$JAVA_HOME"
15-
SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn)
15+
export SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn)
1616
echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN"
1717
./runtime/bin/mead-run-nb-test \
1818
--instance-type ml.c4.8xlarge \

tests/unit/test_estimator.py

+11
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,17 @@ def test_augmented_manifest(sagemaker_session):
315315
assert s3_data_source['AttributeNames'] == ['foo', 'bar']
316316

317317

318+
def test_s3_input_mode(sagemaker_session):
319+
expected_input_mode = 'Pipe'
320+
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
321+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
322+
enable_cloudwatch_metrics=True)
323+
fw.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))
324+
325+
actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
326+
assert actual_input_mode == expected_input_mode
327+
328+
318329
def test_shuffle_config(sagemaker_session):
319330
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
320331
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,

tests/unit/test_tuner.py

+26
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner,
2929
create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig,
3030
WarmStartTypes)
31+
from sagemaker.session import s3_input
3132

3233
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
3334
MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -286,6 +287,31 @@ def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner):
286287
assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids}
287288

288289

290+
def test_s3_input_mode(sagemaker_session, tuner):
291+
expected_input_mode = 'Pipe'
292+
293+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
294+
mxnet = MXNet(entry_point=script_path,
295+
role=ROLE,
296+
framework_version=FRAMEWORK_VERSION,
297+
train_instance_count=TRAIN_INSTANCE_COUNT,
298+
train_instance_type=TRAIN_INSTANCE_TYPE,
299+
sagemaker_session=sagemaker_session)
300+
tuner.estimator = mxnet
301+
302+
tags = [{'Name': 'some-tag-without-a-value'}]
303+
tuner.tags = tags
304+
305+
hyperparameter_ranges = {'num_components': IntegerParameter(2, 4),
306+
'algorithm_mode': CategoricalParameter(['regular', 'randomized'])}
307+
tuner._hyperparameter_ranges = hyperparameter_ranges
308+
309+
tuner.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode))
310+
311+
actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode']
312+
assert actual_input_mode == expected_input_mode
313+
314+
289315
def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner):
290316
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
291317
base_job_name='pca', sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)