diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 36c4a6aa11..ab38b15a27 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -569,6 +569,12 @@ def start_new(cls, estimator, inputs): train_args['tags'] = estimator.tags train_args['metric_definitions'] = estimator.metric_definitions + if isinstance(inputs, s3_input): + if 'InputMode' in inputs.config: + logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.' + .format(inputs.config['InputMode'])) + train_args['input_mode'] = inputs.config['InputMode'] + if estimator.enable_network_isolation(): train_args['enable_network_isolation'] = True diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index ac2dfc2476..12396e922b 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -15,6 +15,7 @@ import importlib import inspect import json +import logging from enum import Enum import sagemaker @@ -26,6 +27,7 @@ from sagemaker.parameter import (CategoricalParameter, ContinuousParameter, IntegerParameter, ParameterRange) from sagemaker.session import Session +from sagemaker.session import s3_input from sagemaker.utils import base_name_from_image, name_from_base, to_str AMAZON_ESTIMATOR_MODULE = 'sagemaker' @@ -640,6 +642,12 @@ def start_new(cls, tuner, inputs): tuner_args['warm_start_config'] = warm_start_config_req tuner_args['early_stopping_type'] = tuner.early_stopping_type + if isinstance(inputs, s3_input): + if 'InputMode' in inputs.config: + logging.debug('Selecting s3_input\'s input_mode ({}) for TrainingInputMode.' + .format(inputs.config['InputMode'])) + tuner_args['input_mode'] = inputs.config['InputMode'] + if isinstance(tuner.estimator, sagemaker.algorithm.AlgorithmEstimator): tuner_args['algorithm_arn'] = tuner.estimator.algorithm_arn else: diff --git a/tests/scripts/run-notebook-test.sh b/tests/scripts/run-notebook-test.sh index 83d7f5753d..be3c9f125f 100755 --- a/tests/scripts/run-notebook-test.sh +++ b/tests/scripts/run-notebook-test.sh @@ -10,9 +10,9 @@ aws s3 --region us-west-2 cp ./dist/sagemaker-*.tar.gz s3://sagemaker-python-sdk aws s3 cp s3://sagemaker-mead-cli/mead-nb-test.tar.gz mead-nb-test.tar.gz tar -xzf mead-nb-test.tar.gz git clone --depth 1 https://github.com/awslabs/amazon-sagemaker-examples.git -JAVA_HOME=$(get-java-home) +export JAVA_HOME=$(get-java-home) echo "set JAVA_HOME=$JAVA_HOME" -SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn) +export SAGEMAKER_ROLE_ARN=$(get-sagemaker-role-arn) echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN" ./runtime/bin/mead-run-nb-test \ --instance-type ml.c4.8xlarge \ diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 548ff15cc9..883ad0c0d2 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -315,6 +315,17 @@ def test_augmented_manifest(sagemaker_session): assert s3_data_source['AttributeNames'] == ['foo', 'bar'] +def test_s3_input_mode(sagemaker_session): + expected_input_mode = 'Pipe' + fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True) + fw.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode)) + + actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode'] + assert actual_input_mode == expected_input_mode + + def test_shuffle_config(sagemaker_session): fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 58e19e8d32..bb57d480af 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -28,6 +28,7 @@ from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner, create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig, WarmStartTypes) +from sagemaker.session import s3_input DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data') MODEL_DATA = "s3://bucket/model.tar.gz" @@ -286,6 +287,31 @@ def test_fit_mxnet_with_vpc_config(sagemaker_session, tuner): assert tune_kwargs['vpc_config'] == {'Subnets': subnets, 'SecurityGroupIds': security_group_ids} +def test_s3_input_mode(sagemaker_session, tuner): + expected_input_mode = 'Pipe' + + script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py') + mxnet = MXNet(entry_point=script_path, + role=ROLE, + framework_version=FRAMEWORK_VERSION, + train_instance_count=TRAIN_INSTANCE_COUNT, + train_instance_type=TRAIN_INSTANCE_TYPE, + sagemaker_session=sagemaker_session) + tuner.estimator = mxnet + + tags = [{'Name': 'some-tag-without-a-value'}] + tuner.tags = tags + + hyperparameter_ranges = {'num_components': IntegerParameter(2, 4), + 'algorithm_mode': CategoricalParameter(['regular', 'randomized'])} + tuner._hyperparameter_ranges = hyperparameter_ranges + + tuner.fit(inputs=s3_input('s3://mybucket/train_manifest', input_mode=expected_input_mode)) + + actual_input_mode = sagemaker_session.method_calls[1][2]['input_mode'] + assert actual_input_mode == expected_input_mode + + def test_fit_pca_with_inter_container_traffic_encryption_flag(sagemaker_session, tuner): pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS, base_job_name='pca', sagemaker_session=sagemaker_session,