Skip to content

Commit abbdbd5

Browse files
Add deploy (aws#36)
1 parent 8c68620 commit abbdbd5

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

src/sagemaker/tuner.py

+28
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,34 @@ def fit(self, inputs, job_name=None, **kwargs):
119119
self.prepare_for_training()
120120
self.latest_tuning_job = _TuningJob.start_new(self, inputs)
121121

122+
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
    """Deploy the best training job's model to an Amazon SageMaker endpoint.

    The name returned by ``self.best_training_job()`` is resolved once, an estimator is attached
    to that training job through ``self.estimator.attach()`` (reusing the estimator's SageMaker
    session), and the result of that estimator's ``deploy()`` call is returned.

    More information:
    http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

    Args:
        initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for
            prediction.
        instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
            for example, 'ml.c4.xlarge'.
        endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified
            (or otherwise falsy), the name of the best training job is used.
        **kwargs: Forwarded unchanged to the attached estimator's ``deploy()``. Presumably these end
            up in that estimator's ``create_model()`` call; see the estimator implementation docs.

    Returns:
        The value returned by the attached estimator's ``deploy()``; this is expected to be a
        ``sagemaker.predictor.RealTimePredictor``, which provides a ``predict()`` method for sending
        requests to the Amazon SageMaker endpoint and obtaining inferences.
    """
    # Resolve the best training job exactly once. Looking it up separately for the endpoint name
    # and for ``attach()`` would repeat the lookup and could pair an endpoint name with a
    # different training job if the answer changed between the two calls.
    best_training_job = self.best_training_job()
    endpoint_name = endpoint_name or best_training_job
    best_estimator = self.estimator.attach(best_training_job,
                                           sagemaker_session=self.estimator.sagemaker_session)
    return best_estimator.deploy(initial_instance_count, instance_type,
                                 endpoint_name=endpoint_name, **kwargs)
149+
122150
def stop_tuning_job(self):
123151
"""Stop latest running tuning job.
124152
"""

tests/unit/test_tuner.py

+64
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
import pytest
1616
from mock import Mock
1717

18+
from sagemaker import RealTimePredictor
1819
from sagemaker.amazon.pca import PCA
1920
from sagemaker.amazon.amazon_estimator import RecordSet
2021
from sagemaker.estimator import Estimator
2122
from sagemaker.tuner import _ParameterRange, ContinuousParameter, IntegerParameter, CategoricalParameter, \
2223
HyperparameterTuner, _TuningJob
2324
from sagemaker.mxnet import MXNet
25+
MODEL_DATA = "s3://bucket/model.tar.gz"
2426

2527
JOB_NAME = 'tuning_job'
2628
REGION = 'us-west-2'
@@ -211,6 +213,68 @@ def test_best_tuning_job_no_best_job(tuner):
211213
assert 'Best training job not available for tuning job:' in str(e)
212214

213215

216+
def test_deploy_default(tuner):
    """Check ``tuner.deploy()`` when no endpoint name is supplied.

    The SageMaker client is mocked so that the tuning job reports ``JOB_NAME`` as its best
    training job and the training job description points at ``MODEL_DATA``. The test then checks
    the positional arguments given to ``create_model`` and the returned predictor.
    """
    session = tuner.estimator.sagemaker_session
    client = session.sagemaker_client

    hyperparameters = {
        'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
        'checkpoint_path': '"s3://other/1508872349"',
        'sagemaker_program': '"iris-dnn-classifier.py"',
        'sagemaker_enable_cloudwatch_metrics': 'false',
        'sagemaker_container_log_level': '"logging.INFO"',
        'sagemaker_job_name': '"neo"',
        'training_steps': '100',
        '_tuning_objective_metric': 'Validation-accuracy',
    }
    training_description = {
        'AlgorithmSpecification': {'TrainingInputMode': 'File', 'TrainingImage': IMAGE_NAME},
        'HyperParameters': hyperparameters,
        'RoleArn': ROLE,
        'ResourceConfig': {'VolumeSizeInGB': 30, 'InstanceCount': 1, 'InstanceType': 'ml.c4.xlarge'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
        'TrainingJobName': 'neo',
        'TrainingJobStatus': 'Completed',
        'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
        'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'},
        'ModelArtifacts': {'S3ModelArtifacts': MODEL_DATA},
    }
    tuning_description = {'BestTrainingJob': {'TrainingJobName': JOB_NAME}}

    # Stub every service call the deploy path is expected to make.
    client.describe_training_job = Mock(name='describe_training_job',
                                        return_value=training_description)
    client.describe_hyper_parameter_tuning_job = Mock(name='describe_hyper_parameter_tuning_job',
                                                      return_value=tuning_description)
    session.log_for_jobs = Mock(name='log_for_jobs')

    tuner.latest_tuning_job = _TuningJob(session, JOB_NAME)
    predictor = tuner.deploy(TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE)

    # Exactly one model is created; inspect its positional arguments by index.
    session.create_model.assert_called_once()
    positional = session.create_model.call_args[0]
    assert positional[0].startswith(IMAGE_NAME)
    assert positional[1] == ROLE
    assert positional[2]['Image'] == IMAGE_NAME
    assert positional[2]['ModelDataUrl'] == MODEL_DATA

    assert isinstance(predictor, RealTimePredictor)
    assert predictor.endpoint.startswith(JOB_NAME)
    assert predictor.sagemaker_session == tuner.estimator.sagemaker_session
276+
277+
214278
#################################################################################
215279
# _ParameterRange Tests
216280

0 commit comments

Comments
 (0)