10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- import logging
14
-
15
13
import json
14
+ import logging
16
15
import os
16
+
17
17
import pytest
18
18
from mock import Mock , patch
19
+
20
+ from sagemaker .fw_utils import create_image_uri
19
21
from sagemaker .model import MODEL_SERVER_WORKERS_PARAM_NAME
20
22
from sagemaker .session import s3_input
21
- from sagemaker .tensorflow import TensorFlow
22
- from sagemaker .tensorflow import defaults
23
- from sagemaker .fw_utils import create_image_uri
24
- from sagemaker .tensorflow import TensorFlowPredictor , TensorFlowModel
23
+ from sagemaker .tensorflow import defaults , TensorFlow , TensorFlowPredictor , TensorFlowModel
25
24
26
25
DATA_DIR = os .path .join (os .path .dirname (__file__ ), '..' , 'data' )
27
- SCRIPT_PATH = os .path .join (DATA_DIR , 'dummy_script.py' )
26
+ SCRIPT_FILE = 'dummy_script.py'
27
+ SCRIPT_PATH = os .path .join (DATA_DIR , SCRIPT_FILE )
28
+ REQUIREMENTS_FILE = 'dummy_requirements.txt'
28
29
TIMESTAMP = '2017-11-06-14:14:15.673'
29
30
TIME = 1510006209.073025
30
31
BUCKET_NAME = 'mybucket'
@@ -85,6 +86,7 @@ def _create_train_job(tf_version):
85
86
'training_steps' : '1000' ,
86
87
'evaluation_steps' : '10' ,
87
88
'sagemaker_program' : json .dumps ('dummy_script.py' ),
89
+ 'sagemaker_requirements' : '"{}"' .format (REQUIREMENTS_FILE ),
88
90
'sagemaker_submit_directory' : json .dumps ('s3://{}/{}/source/sourcedir.tar.gz' .format (
89
91
BUCKET_NAME , JOB_NAME )),
90
92
'sagemaker_enable_cloudwatch_metrics' : 'false' ,
@@ -100,10 +102,10 @@ def _create_train_job(tf_version):
100
102
101
103
def _build_tf (sagemaker_session , framework_version = defaults .TF_VERSION , train_instance_type = None ,
102
104
checkpoint_path = None , enable_cloudwatch_metrics = False , base_job_name = None ,
103
- training_steps = None , evalutation_steps = None , ** kwargs ):
105
+ training_steps = None , evaluation_steps = None , ** kwargs ):
104
106
return TensorFlow (entry_point = SCRIPT_PATH ,
105
107
training_steps = training_steps ,
106
- evaluation_steps = evalutation_steps ,
108
+ evaluation_steps = evaluation_steps ,
107
109
framework_version = framework_version ,
108
110
role = ROLE ,
109
111
sagemaker_session = sagemaker_session ,
@@ -158,6 +160,20 @@ def test_tf_deploy_model_server_workers_unset(sagemaker_session):
158
160
assert MODEL_SERVER_WORKERS_PARAM_NAME .upper () not in sagemaker_session .method_calls [3 ][1 ][2 ]['Environment' ]
159
161
160
162
163
def test_tf_invalid_requirements_path(sagemaker_session):
    """An absolute requirements path must be rejected — it cannot be
    resolved relative to source_dir."""
    bad_path = '/foo/bar/requirements.txt'
    with pytest.raises(ValueError) as e:
        _build_tf(sagemaker_session, requirements_file=bad_path, source_dir=DATA_DIR)
    expected = 'Requirements file {} is not a path relative to source_dir.'.format(bad_path)
    assert expected in str(e.value)
168
+
169
+
170
def test_tf_nonexistent_requirements_path(sagemaker_session):
    """A requirements file that is relative to source_dir but missing on
    disk must raise a ValueError."""
    missing = 'nonexistent_requirements.txt'
    with pytest.raises(ValueError) as e:
        _build_tf(sagemaker_session, requirements_file=missing, source_dir=DATA_DIR)
    expected = 'Requirements file {} does not exist.'.format(missing)
    assert expected in str(e.value)
175
+
176
+
161
177
def test_create_model (sagemaker_session , tf_version ):
162
178
container_log_level = '"logging.INFO"'
163
179
source_dir = 's3://mybucket/source'
@@ -186,9 +202,9 @@ def test_create_model(sagemaker_session, tf_version):
186
202
@patch ('time.strftime' , return_value = TIMESTAMP )
187
203
@patch ('time.time' , return_value = TIME )
188
204
def test_tf (time , strftime , sagemaker_session , tf_version ):
189
- tf = TensorFlow (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
190
- training_steps = 1000 , evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT ,
191
- train_instance_type = INSTANCE_TYPE , framework_version = tf_version )
205
+ tf = TensorFlow (entry_point = SCRIPT_FILE , role = ROLE , sagemaker_session = sagemaker_session , training_steps = 1000 ,
206
+ evaluation_steps = 10 , train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
207
+ framework_version = tf_version , requirements_file = REQUIREMENTS_FILE , source_dir = DATA_DIR )
192
208
193
209
inputs = 's3://mybucket/train'
194
210
@@ -210,6 +226,7 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
210
226
assert {'Environment' :
211
227
{'SAGEMAKER_SUBMIT_DIRECTORY' : 's3://{}/{}/sourcedir.tar.gz' .format (BUCKET_NAME , JOB_NAME ),
212
228
'SAGEMAKER_PROGRAM' : 'dummy_script.py' ,
229
+ 'SAGEMAKER_REQUIREMENTS' : 'dummy_requirements.txt' ,
213
230
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS' : 'false' ,
214
231
'SAGEMAKER_REGION' : 'us-west-2' ,
215
232
'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20'
@@ -315,7 +332,7 @@ def test_tf_training_and_evaluation_steps_not_set(sagemaker_session):
315
332
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
316
333
output_path = "s3://{}/output/{}/" .format (sagemaker_session .default_bucket (), job_name )
317
334
318
- tf = _build_tf (sagemaker_session , training_steps = None , evalutation_steps = None , output_path = output_path )
335
+ tf = _build_tf (sagemaker_session , training_steps = None , evaluation_steps = None , output_path = output_path )
319
336
tf .fit (inputs = s3_input ('s3://mybucket/train' ))
320
337
assert tf .hyperparameters ()['training_steps' ] == 'null'
321
338
assert tf .hyperparameters ()['evaluation_steps' ] == 'null'
@@ -325,7 +342,7 @@ def test_tf_training_and_evaluation_steps(sagemaker_session):
325
342
job_name = "sagemaker-tensorflow-py2-gpu-2017-10-24-14-12-09"
326
343
output_path = "s3://{}/output/{}/" .format (sagemaker_session .default_bucket (), job_name )
327
344
328
- tf = _build_tf (sagemaker_session , training_steps = 123 , evalutation_steps = 456 , output_path = output_path )
345
+ tf = _build_tf (sagemaker_session , training_steps = 123 , evaluation_steps = 456 , output_path = output_path )
329
346
tf .fit (inputs = s3_input ('s3://mybucket/train' ))
330
347
assert tf .hyperparameters ()['training_steps' ] == '123'
331
348
assert tf .hyperparameters ()['evaluation_steps' ] == '456'
0 commit comments