Skip to content

Commit e068276

Browse files
authored
Merge branch 'master' into master
2 parents 7c80d16 + 8a3dea2 commit e068276

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

README.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,10 @@ you can specify these as keyword arguments.
10911091
other training source code dependencies aside from the entry point
10921092
file. Structure within this directory will be preserved when training
10931093
on SageMaker.
1094+
- ``requirements_file (str)`` Path to a ``requirements.txt`` file. The path should
1095+
be within and relative to ``source_dir``. This is a file containing a list of items to be
1096+
installed using pip install. Details on the format can be found in the
1097+
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
10941098
- ``hyperparameters (dict[str,ANY])`` Hyperparameters that will be used for training.
10951099
Will be made accessible as a dict[] to the training code on
10961100
SageMaker. Some hyperparameters will be interpreted by TensorFlow and can be used to

src/sagemaker/tensorflow/estimator.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ class TensorFlow(Framework):
108108

109109
__framework_name__ = 'tensorflow'
110110

111-
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version="py2",
112-
framework_version=TF_VERSION, **kwargs):
111+
def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=None, py_version='py2',
112+
framework_version=TF_VERSION, requirements_file='', **kwargs):
113113
"""Initialize an ``TensorFlow`` estimator.
114114
Args:
115115
training_steps (int): Perform this many steps of training. `None`, the default, means train forever.
@@ -120,6 +120,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
120120
py_version (str): Python version you want to use for executing your model training code (default: 'py2').
121121
framework_version (str): TensorFlow version you want to use for executing your model training code.
122122
List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
123+
requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
124+
relative to ``source_dir``. Details on the format can be found in the
125+
`Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
123126
**kwargs: Additional kwargs passed to the Framework constructor.
124127
"""
125128
super(TensorFlow, self).__init__(**kwargs)
@@ -129,6 +132,22 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
129132
self.training_steps = training_steps
130133
self.evaluation_steps = evaluation_steps
131134

135+
self._validate_requirements_file(requirements_file)
136+
self.requirements_file = requirements_file
137+
138+
def _validate_requirements_file(self, requirements_file):
139+
if not requirements_file:
140+
return
141+
142+
if not self.source_dir:
143+
raise ValueError('Must specify source_dir along with a requirements file.')
144+
145+
if os.path.isabs(requirements_file):
146+
raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(requirements_file))
147+
148+
if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
149+
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
150+
132151
def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
133152
"""Train a model using the input training dataset.
134153
@@ -228,11 +247,13 @@ def create_model(self, model_server_workers=None):
228247
sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
229248
See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
230249
"""
250+
env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}
231251
return TensorFlowModel(self.model_data, self.role, self.entry_point, source_dir=self.source_dir,
232-
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, name=self._current_job_name,
233-
container_log_level=self.container_log_level, code_location=self.code_location,
234-
py_version=self.py_version, framework_version=self.framework_version,
235-
model_server_workers=model_server_workers, sagemaker_session=self.sagemaker_session)
252+
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics, env=env,
253+
name=self._current_job_name, container_log_level=self.container_log_level,
254+
code_location=self.code_location, py_version=self.py_version,
255+
framework_version=self.framework_version, model_server_workers=model_server_workers,
256+
sagemaker_session=self.sagemaker_session)
236257

237258
def hyperparameters(self):
238259
"""Return hyperparameters used by your custom TensorFlow code during model training."""
@@ -243,7 +264,8 @@ def hyperparameters(self):
243264

244265
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
245266
'training_steps': self.training_steps,
246-
'evaluation_steps': self.evaluation_steps}
267+
'evaluation_steps': self.evaluation_steps,
268+
'sagemaker_requirements': self.requirements_file}
247269

248270
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
249271
return hyperparameters

tests/data/dummy_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fake-requirement-for-unit-tests==1.0.0

tests/unit/test_tf_estimator.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,22 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
import logging
14-
1513
import json
14+
import logging
1615
import os
16+
1717
import pytest
1818
from mock import Mock, patch
19+
20+
from sagemaker.fw_utils import create_image_uri
1921
from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
2022
from sagemaker.session import s3_input
21-
from sagemaker.tensorflow import TensorFlow
22-
from sagemaker.tensorflow import defaults
23-
from sagemaker.fw_utils import create_image_uri
24-
from sagemaker.tensorflow import TensorFlowPredictor, TensorFlowModel
23+
from sagemaker.tensorflow import defaults, TensorFlow, TensorFlowPredictor, TensorFlowModel
2524

2625
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
27-
SCRIPT_PATH = os.path.join(DATA_DIR, 'dummy_script.py')
26+
SCRIPT_FILE = 'dummy_script.py'
27+
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_FILE)
28+
REQUIREMENTS_FILE = 'dummy_requirements.txt'
2829
TIMESTAMP = '2017-11-06-14:14:15.673'
2930
TIME = 1510006209.073025
3031
BUCKET_NAME = 'mybucket'
@@ -85,6 +86,7 @@ def _create_train_job(tf_version):
8586
'training_steps': '1000',
8687
'evaluation_steps': '10',
8788
'sagemaker_program': json.dumps('dummy_script.py'),
89+
'sagemaker_requirements': '"{}"'.format(REQUIREMENTS_FILE),
8890
'sagemaker_submit_directory': json.dumps('s3://{}/{}/source/sourcedir.tar.gz'.format(
8991
BUCKET_NAME, JOB_NAME)),
9092
'sagemaker_enable_cloudwatch_metrics': 'false',
@@ -100,10 +102,10 @@ def _create_train_job(tf_version):
100102

101103
def _build_tf(sagemaker_session, framework_version=defaults.TF_VERSION, train_instance_type=None,
102104
checkpoint_path=None, enable_cloudwatch_metrics=False, base_job_name=None,
103-
training_steps=None, evalutation_steps=None, **kwargs):
105+
training_steps=None, evaluation_steps=None, **kwargs):
104106
return TensorFlow(entry_point=SCRIPT_PATH,
105107
training_steps=training_steps,
106-
evaluation_steps=evalutation_steps,
108+
evaluation_steps=evaluation_steps,
107109
framework_version=framework_version,
108110
role=ROLE,
109111
sagemaker_session=sagemaker_session,
@@ -158,6 +160,20 @@ def test_tf_deploy_model_server_workers_unset(sagemaker_session):
158160
assert MODEL_SERVER_WORKERS_PARAM_NAME.upper() not in sagemaker_session.method_calls[3][1][2]['Environment']
159161

160162

163+
def test_tf_invalid_requirements_path(sagemaker_session):
164+
requirements_file = '/foo/bar/requirements.txt'
165+
with pytest.raises(ValueError) as e:
166+
_build_tf(sagemaker_session, requirements_file=requirements_file, source_dir=DATA_DIR)
167+
assert 'Requirements file {} is not a path relative to source_dir.'.format(requirements_file) in str(e.value)
168+
169+
170+
def test_tf_nonexistent_requirements_path(sagemaker_session):
171+
requirements_file = 'nonexistent_requirements.txt'
172+
with pytest.raises(ValueError) as e:
173+
_build_tf(sagemaker_session, requirements_file=requirements_file, source_dir=DATA_DIR)
174+
assert 'Requirements file {} does not exist.'.format(requirements_file) in str(e.value)
175+
176+
161177
def test_create_model(sagemaker_session, tf_version):
162178
container_log_level = '"logging.INFO"'
163179
source_dir = 's3://mybucket/source'
@@ -186,9 +202,9 @@ def test_create_model(sagemaker_session, tf_version):
186202
@patch('time.strftime', return_value=TIMESTAMP)
187203
@patch('time.time', return_value=TIME)
188204
def test_tf(time, strftime, sagemaker_session, tf_version):
189-
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
190-
training_steps=1000, evaluation_steps=10, train_instance_count=INSTANCE_COUNT,
191-
train_instance_type=INSTANCE_TYPE, framework_version=tf_version)
205+
tf = TensorFlow(entry_point=SCRIPT_FILE, role=ROLE, sagemaker_session=sagemaker_session, training_steps=1000,
206+
evaluation_steps=10, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
207+
framework_version=tf_version, requirements_file=REQUIREMENTS_FILE, source_dir=DATA_DIR)
192208

193209
inputs = 's3://mybucket/train'
194210

@@ -210,6 +226,7 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
210226
assert {'Environment':
211227
{'SAGEMAKER_SUBMIT_DIRECTORY': 's3://{}/{}/sourcedir.tar.gz'.format(BUCKET_NAME, JOB_NAME),
212228
'SAGEMAKER_PROGRAM': 'dummy_script.py',
229+
'SAGEMAKER_REQUIREMENTS': 'dummy_requirements.txt',
213230
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
214231
'SAGEMAKER_REGION': 'us-west-2',
215232
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
@@ -315,7 +332,7 @@ def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
315332
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
316333
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)
317334

318-
tf = _build_tf(sagemaker_session, training_steps=None, evalutation_steps=None, output_path=output_path)
335+
tf = _build_tf(sagemaker_session, training_steps=None, evaluation_steps=None, output_path=output_path)
319336
tf.fit(inputs=s3_input('s3://mybucket/train'))
320337
assert tf.hyperparameters()['training_steps'] == 'null'
321338
assert tf.hyperparameters()['evaluation_steps'] == 'null'
@@ -325,7 +342,7 @@ def test_tf_training_and_evaluation_steps(sagemaker_session):
325342
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
326343
output_path = "s3://{}/output/{}/".format(sagemaker_session.default_bucket(), job_name)
327344

328-
tf = _build_tf(sagemaker_session, training_steps=123, evalutation_steps=456, output_path=output_path)
345+
tf = _build_tf(sagemaker_session, training_steps=123, evaluation_steps=456, output_path=output_path)
329346
tf.fit(inputs=s3_input('s3://mybucket/train'))
330347
assert tf.hyperparameters()['training_steps'] == '123'
331348
assert tf.hyperparameters()['evaluation_steps'] == '456'

0 commit comments

Comments
 (0)