
Commit 3fb5516

Add integ tests for tuning jobs (aws#220)
1 parent 0cc5ccc commit 3fb5516

6 files changed: +268 -48 lines changed
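
The scripts changed below each gain a tunable hyperparameter for the new tuning-job integration tests; the test file itself is not part of this excerpt. As a rough sketch of the flow such a test exercises with this SDK's tuner API (the estimator, metric name/regex, ranges, and S3 inputs are illustrative assumptions, not taken from the commit):

# Sketch only, not the commit's test file.
# `estimator` is assumed to be an already-configured estimator, e.g. one running
# tests/data/chainer_mnist/mnist.py; metric regex and S3 prefixes are hypothetical.
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

train_input = 's3://my-bucket/chainer-mnist/train'   # hypothetical S3 prefixes
test_input = 's3://my-bucket/chainer-mnist/test'

tuner = HyperparameterTuner(estimator=estimator,
                            objective_metric_name='Validation-accuracy',
                            metric_definitions=[{'Name': 'Validation-accuracy',
                                                 'Regex': r'validation accuracy=(\d\.\d+)'}],
                            hyperparameter_ranges={'alpha': ContinuousParameter(0.0001, 0.1)},
                            max_jobs=2,
                            max_parallel_jobs=2)

tuner.fit({'train': train_input, 'test': test_input})

# The tuning job can later be re-attached by name, e.g.
# HyperparameterTuner.attach(tuner.latest_tuning_job.name)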

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ CHANGELOG
 ========
 
 * bug-fix: Unit Tests: Improve unit test runtime
+* bug-fix: Estimators: Fix attach for LDA
 
 1.4.1
 =====

src/sagemaker/amazon/lda.py

Lines changed: 3 additions & 1 deletion

@@ -78,8 +78,10 @@ def __init__(self, role, train_instance_type, num_topics,
             tol (float): Optional. Target error tolerance for the ALS phase of the algorithm.
             **kwargs: base class keyword argument values.
         """
-
         # this algorithm only supports single instance training
+        if kwargs.pop('train_instance_count', 1) != 1:
+            print('LDA only supports single instance training. Defaulting to 1 {}.'.format(train_instance_type))
+
         super(LDA, self).__init__(role, 1, train_instance_type, **kwargs)
         self.num_topics = num_topics
         self.alpha0 = alpha0
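
With this change, passing an explicit `train_instance_count` no longer makes the base class receive the argument twice (once positionally, once via `kwargs`): the value is popped, a notice is printed, and training proceeds on a single instance. A minimal sketch of the resulting behaviour, assuming AWS credentials and a region are configured (the role ARN and instance type are placeholders):

from sagemaker.amazon.lda import LDA

# train_instance_count is popped by __init__; LDA prints a notice and forces a count of 1.
lda = LDA(role='arn:aws:iam::111122223333:role/SageMakerRole',
          train_instance_type='ml.c4.xlarge',
          num_topics=10,
          train_instance_count=2)
assert lda.train_instance_count == 1   # LDA always trains on a single instance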

tests/conftest.py

Lines changed: 3 additions & 1 deletion

@@ -16,6 +16,7 @@
 
 import boto3
 import pytest
+from botocore.config import Config
 
 from sagemaker import Session
 from sagemaker.local import LocalSession
@@ -32,7 +33,7 @@ def pytest_addoption(parser):
 @pytest.fixture(scope='session')
 def sagemaker_client_config(request):
     config = request.config.getoption('--sagemaker-client-config')
-    return json.loads(config) if config else None
+    return json.loads(config) if config else dict()
 
 
 @pytest.fixture(scope='session')
@@ -50,6 +51,7 @@ def boto_config(request):
 @pytest.fixture(scope='session')
 def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_config):
     boto_session = boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
+    sagemaker_client_config.setdefault('config', Config(retries=dict(max_attempts=10)))
     sagemaker_client = boto_session.client('sagemaker', **sagemaker_client_config) if sagemaker_client_config else None
     runtime_client = (boto_session.client('sagemaker-runtime', **sagemaker_runtime_config) if sagemaker_runtime_config
                       else None)
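
The session fixture now injects a botocore retry configuration, which helps long-running tuning-job integration tests that poll the SageMaker API and can otherwise be throttled; returning `dict()` instead of `None` from `sagemaker_client_config` is what makes the unconditional `setdefault` call safe. A minimal illustration of the same retry setup outside pytest (region and attempt count are example values):

import boto3
from botocore.config import Config

boto_session = boto3.Session(region_name='us-west-2')
# botocore retries throttled or failed SageMaker API calls up to 10 times.
sagemaker_client = boto_session.client('sagemaker',
                                       config=Config(retries=dict(max_attempts=10)))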

tests/data/chainer_mnist/mnist.py

Lines changed: 2 additions & 1 deletion

@@ -72,6 +72,7 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb
     parser.add_argument('--epochs', type=int, default=20)
     parser.add_argument('--frequency', type=int, default=20)
     parser.add_argument('--batch-size', type=int, default=100)
+    parser.add_argument('--alpha', type=float, default=0.001)
     parser.add_argument('--model-dir', type=str, default=env.model_dir)
 
     parser.add_argument('--train', type=str, default=env.channel_input_dirs['train'])
@@ -103,7 +104,7 @@ def _preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype, rgb
         chainer.cuda.get_device_from_id(0).use()
 
     # Setup an optimizer
-    optimizer = chainer.optimizers.Adam()
+    optimizer = chainer.optimizers.Adam(alpha=args.alpha)
     optimizer.setup(model)
 
     # Load the MNIST dataset
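
The MNIST script now exposes Adam's `alpha` (Chainer's name for the Adam step size, default 0.001) as a `--alpha` argument, so a value chosen by a tuning job can reach the optimizer. A standalone sketch of that wiring, with the model itself omitted:

import argparse

import chainer

parser = argparse.ArgumentParser()
parser.add_argument('--alpha', type=float, default=0.001)   # Adam's step-size parameter
args = parser.parse_args(['--alpha', '0.05'])                # e.g. a value chosen by the tuner

optimizer = chainer.optimizers.Adam(alpha=args.alpha)        # alpha is Adam's learning-rate knob in Chainer
# optimizer.setup(model) would follow once a chainer.Chain `model` exists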

tests/data/iris/iris-dnn-classifier.py

Lines changed: 3 additions & 1 deletion

@@ -18,10 +18,12 @@
 
 
 def estimator_fn(run_config, hyperparameters):
-    input_tensor_name = hyperparameters['input_tensor_name']
+    input_tensor_name = hyperparameters.get('input_tensor_name', 'inputs')
+    learning_rate = hyperparameters.get('learning_rate', 0.05)
     feature_columns = [tf.feature_column.numeric_column(input_tensor_name, shape=[4])]
     return tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                       hidden_units=[10, 20, 10],
+                                      optimizer=tf.train.AdagradOptimizer(learning_rate=learning_rate),
                                       n_classes=3,
                                       config=run_config)
 
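
`estimator_fn` now falls back to defaults when a hyperparameter is missing and threads `learning_rate` into an Adagrad optimizer, which is what lets a tuning job vary it. A rough sketch of calling the function directly, assuming the script were importable as `iris_dnn_classifier` (hypothetical module name) under TensorFlow 1.x:

import tensorflow as tf

from iris_dnn_classifier import estimator_fn  # hypothetical: the file would need a package-friendly name

# 'input_tensor_name' falls back to 'inputs'; 'learning_rate' here overrides the 0.05 default.
classifier = estimator_fn(run_config=tf.estimator.RunConfig(),
                          hyperparameters={'learning_rate': 0.1})
print(type(classifier))   # a DNNClassifier configured with AdagradOptimizer(learning_rate=0.1)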
