|
15 | 15 | import pytest
|
16 | 16 | from mock import Mock
|
17 | 17 |
|
| 18 | +from sagemaker import RealTimePredictor |
18 | 19 | from sagemaker.amazon.pca import PCA
|
19 | 20 | from sagemaker.amazon.amazon_estimator import RecordSet
|
20 | 21 | from sagemaker.estimator import Estimator
|
21 | 22 | from sagemaker.tuner import _ParameterRange, ContinuousParameter, IntegerParameter, CategoricalParameter, \
|
22 | 23 | HyperparameterTuner, _TuningJob
|
23 | 24 | from sagemaker.mxnet import MXNet
|
| 25 | +MODEL_DATA = "s3://bucket/model.tar.gz" |
24 | 26 |
|
25 | 27 | JOB_NAME = 'tuning_job'
|
26 | 28 | REGION = 'us-west-2'
|
@@ -211,6 +213,68 @@ def test_best_tuning_job_no_best_job(tuner):
|
211 | 213 | assert 'Best training job not available for tuning job:' in str(e)
|
212 | 214 |
|
213 | 215 |
|
| 216 | +def test_deploy_default(tuner): |
| 217 | + returned_training_job_description = { |
| 218 | + 'AlgorithmSpecification': { |
| 219 | + 'TrainingInputMode': 'File', |
| 220 | + 'TrainingImage': IMAGE_NAME |
| 221 | + }, |
| 222 | + 'HyperParameters': { |
| 223 | + 'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"', |
| 224 | + 'checkpoint_path': '"s3://other/1508872349"', |
| 225 | + 'sagemaker_program': '"iris-dnn-classifier.py"', |
| 226 | + 'sagemaker_enable_cloudwatch_metrics': 'false', |
| 227 | + 'sagemaker_container_log_level': '"logging.INFO"', |
| 228 | + 'sagemaker_job_name': '"neo"', |
| 229 | + 'training_steps': '100', |
| 230 | + '_tuning_objective_metric': 'Validation-accuracy', |
| 231 | + }, |
| 232 | + |
| 233 | + 'RoleArn': ROLE, |
| 234 | + 'ResourceConfig': { |
| 235 | + 'VolumeSizeInGB': 30, |
| 236 | + 'InstanceCount': 1, |
| 237 | + 'InstanceType': 'ml.c4.xlarge' |
| 238 | + }, |
| 239 | + 'StoppingCondition': { |
| 240 | + 'MaxRuntimeInSeconds': 24 * 60 * 60 |
| 241 | + }, |
| 242 | + 'TrainingJobName': 'neo', |
| 243 | + 'TrainingJobStatus': 'Completed', |
| 244 | + 'OutputDataConfig': { |
| 245 | + 'KmsKeyId': '', |
| 246 | + 'S3OutputPath': 's3://place/output/neo' |
| 247 | + }, |
| 248 | + 'TrainingJobOutput': { |
| 249 | + 'S3TrainingJobOutput': 's3://here/output.tar.gz' |
| 250 | + }, |
| 251 | + 'ModelArtifacts': { |
| 252 | + 'S3ModelArtifacts': MODEL_DATA |
| 253 | + } |
| 254 | + } |
| 255 | + tuning_job_description = {'BestTrainingJob': {'TrainingJobName': JOB_NAME}} |
| 256 | + |
| 257 | + tuner.estimator.sagemaker_session.sagemaker_client.describe_training_job = \ |
| 258 | + Mock(name='describe_training_job', return_value=returned_training_job_description) |
| 259 | + tuner.estimator.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( |
| 260 | + name='describe_hyper_parameter_tuning_job', return_value=tuning_job_description) |
| 261 | + tuner.estimator.sagemaker_session.log_for_jobs = Mock(name='log_for_jobs') |
| 262 | + |
| 263 | + tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME) |
| 264 | + predictor = tuner.deploy(TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE) |
| 265 | + |
| 266 | + tuner.estimator.sagemaker_session.create_model.assert_called_once() |
| 267 | + args = tuner.estimator.sagemaker_session.create_model.call_args[0] |
| 268 | + assert args[0].startswith(IMAGE_NAME) |
| 269 | + assert args[1] == ROLE |
| 270 | + assert args[2]['Image'] == IMAGE_NAME |
| 271 | + assert args[2]['ModelDataUrl'] == MODEL_DATA |
| 272 | + |
| 273 | + assert isinstance(predictor, RealTimePredictor) |
| 274 | + assert predictor.endpoint.startswith(JOB_NAME) |
| 275 | + assert predictor.sagemaker_session == tuner.estimator.sagemaker_session |
| 276 | + |
| 277 | + |
214 | 278 | #################################################################################
|
215 | 279 | # _ParameterRange Tests
|
216 | 280 |
|
|
0 commit comments