
Commit 507f2cd

RodrigoAtAWS authored and nadiaya committed
Add model parameters to Estimator, and bump library version to 1.13.0 (aws#450)
* Add incremental training model parameters to Estimator, and bump library version to 1.13.0
* Update README.rst: removed unnecessary comma from sentence.
* Update estimator.py: removed duplicated line
1 parent a5595af commit 507f2cd

14 files changed: +401 -102 lines changed


CHANGELOG.rst (+6)

@@ -2,6 +2,12 @@
 CHANGELOG
 =========
 
+1.13.0
+======
+
+* feature: Estimator: add input mode to training channels
+* feature: Estimator: add model_uri and model_channel_name parameters
+
 1.12.0
 ======
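The first changelog entry above refers to per-channel input modes. Below is a minimal sketch of how that might be used, assuming ``s3_input`` accepts the new ``input_mode`` argument in this release; the image URI, role ARN and S3 paths are placeholders, not taken from this commit.

.. code:: python

    import sagemaker
    from sagemaker.session import s3_input

    # Placeholders: an existing training image and IAM role.
    training_image = '123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest'
    role = 'arn:aws:iam::123456789012:role/MySageMakerRole'

    # Stream this one channel in Pipe mode even though the estimator-wide default is 'File'
    # (assumes the input_mode argument on s3_input is part of the 1.13.0 feature).
    train_channel = s3_input('s3://my_bucket/my_training_data/', input_mode='Pipe')

    estimator = sagemaker.estimator.Estimator(training_image,
                                              role,
                                              train_instance_count=1,
                                              train_instance_type='ml.p2.xlarge',
                                              input_mode='File')

    estimator.fit({'train': train_channel})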

README.rst (+60)

@@ -263,6 +263,66 @@ A few important notes:
 - Local Mode requires Docker Compose and `nvidia-docker2 <https://github.com/NVIDIA/nvidia-docker>`__ for ``local_gpu``.
 - Distributed training is not yet supported for ``local_gpu``.
 
+Incremental Training
+~~~~~~~~~~~~~~~~~~~~
+
+Incremental training allows you to bring a pre-trained model into a SageMaker training job and use it as a starting point for a new model.
+There are several situations where you might want to do this:
+
+- You want to perform additional training on a model to improve its fit on your data set.
+- You want to import a pre-trained model and fit it to your data.
+- You want to resume a training job that you previously stopped.
+
+To use incremental training with SageMaker algorithms, you need model artifacts compressed into a ``tar.gz`` file. These
+artifacts are passed to a training job via an input channel configured with the pre-defined settings Amazon SageMaker algorithms require.
+
+To use model files with a SageMaker estimator, you can use the following parameters:
+
+* ``model_uri``: points to the location of a model tarball, either in S3 or locally. Specifying a local path only works in local mode.
+* ``model_channel_name``: name of the channel SageMaker will use to download the tarball specified in ``model_uri``. Defaults to 'model'.
+
+This is converted into an input channel with the specifications mentioned above once you call ``fit()`` on the estimator.
+In bring-your-own cases, ``model_channel_name`` can be overridden if you need to change the name of the channel while using
+the same settings.
+
+If your bring-your-own case requires different settings, you can create your own ``s3_input`` object with the settings you require.
+
+Here's an example of how to use incremental training:
+
+.. code:: python
+
+    # Configure an estimator
+    estimator = sagemaker.estimator.Estimator(training_image,
+                                              role,
+                                              train_instance_count=1,
+                                              train_instance_type='ml.p2.xlarge',
+                                              train_volume_size=50,
+                                              train_max_run=360000,
+                                              input_mode='File',
+                                              output_path=s3_output_location)
+
+    # Start a SageMaker training job and wait until completion.
+    estimator.fit('s3://my_bucket/my_training_data/')
+
+    # Create a new estimator using the previous training job's model artifacts
+    incr_estimator = sagemaker.estimator.Estimator(training_image,
+                                                   role,
+                                                   train_instance_count=1,
+                                                   train_instance_type='ml.p2.xlarge',
+                                                   train_volume_size=50,
+                                                   train_max_run=360000,
+                                                   input_mode='File',
+                                                   output_path=s3_output_location,
+                                                   model_uri=estimator.model_data)
+
+    # Start a SageMaker training job using the original model for incremental training
+    incr_estimator.fit('s3://my_bucket/my_training_data/')
+
+Currently, the following algorithms support incremental training:
+
+- Image Classification
+- Object Detection
+- Semantic Segmentation
+
 
 MXNet SageMaker Estimators
 --------------------------
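The new README section mentions overriding ``model_channel_name`` for bring-your-own containers and building a custom ``s3_input`` when the pre-defined channel settings don't fit. A minimal sketch of both cases follows; the image URI, role ARN, bucket paths and ``content_type`` value are placeholder assumptions, not taken from this commit.

.. code:: python

    import sagemaker
    from sagemaker.session import s3_input

    # Placeholders for resources you already have.
    training_image = '123456789012.dkr.ecr.us-west-2.amazonaws.com/my-byo-image:latest'
    role = 'arn:aws:iam::123456789012:role/MySageMakerRole'
    s3_output_location = 's3://my_bucket/output/'

    # Case 1: keep the pre-defined model channel settings, but rename the channel.
    byo_estimator = sagemaker.estimator.Estimator(training_image,
                                                  role,
                                                  train_instance_count=1,
                                                  train_instance_type='ml.p2.xlarge',
                                                  output_path=s3_output_location,
                                                  model_uri='s3://my_bucket/pretrained/model.tar.gz',
                                                  model_channel_name='pretrained')
    byo_estimator.fit('s3://my_bucket/my_training_data/')

    # Case 2: different channel settings -- skip model_uri and pass your own s3_input in fit().
    custom_model_channel = s3_input('s3://my_bucket/pretrained/model.tar.gz',
                                    content_type='application/x-sagemaker-model')
    plain_estimator = sagemaker.estimator.Estimator(training_image,
                                                    role,
                                                    train_instance_count=1,
                                                    train_instance_type='ml.p2.xlarge',
                                                    output_path=s3_output_location)
    plain_estimator.fit({'train': 's3://my_bucket/my_training_data/',
                         'pretrained': custom_model_channel})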

src/sagemaker/__init__.py (+1, -1)

@@ -35,4 +35,4 @@
 from sagemaker.session import s3_input  # noqa: F401
 from sagemaker.session import get_execution_role  # noqa: F401
 
-__version__ = '1.12.0'
+__version__ = '1.13.0'

src/sagemaker/amazon/amazon_estimator.py (+4, -2)

@@ -70,17 +70,19 @@ def data_location(self, data_location):
         self._data_location = data_location
 
     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details):
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
         """Convert the job description to init params that can be handled by the class constructor
 
         Args:
             job_details: the returned job details from a describe_training_job API call.
+            model_channel_name (str): Name of the channel where pre-trained model data will be downloaded.
 
         Returns:
             dictionary: The transformed init_params
 
         """
-        init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)
+        init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(
+            job_details, model_channel_name)
 
         # The hyperparam names may not be the same as the class attribute that holds them,
         # for instance: local_lloyd_init_method is called local_init_method. We need to map these

src/sagemaker/chainer/estimator.py (+3, -2)

@@ -134,17 +134,18 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
                             vpc_config=self.get_vpc_config(vpc_config_override))
 
     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details):
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
         """Convert the job description to init params that can be handled by the class constructor
 
         Args:
             job_details: the returned job details from a describe_training_job API call.
+            model_channel_name (str): Name of the channel where pre-trained model data will be downloaded.
 
         Returns:
             dictionary: The transformed init_params
 
         """
-        init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details)
+        init_params = super(Chainer, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
 
         for argument in [Chainer._use_mpi, Chainer._num_processes, Chainer._process_slots_per_host,
                          Chainer._additional_mpi_options]:
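The estimator diffs above thread a new ``model_channel_name`` argument through ``_prepare_init_params_from_job_description``, the helper that ``attach()`` uses to rebuild an estimator from an existing training job. A short sketch of how that could surface, assuming ``attach()`` forwards the argument in this release; the job name and instance type are placeholders.

.. code:: python

    import sagemaker

    # Reattach to a finished training job that used a non-default model channel
    # (assumes attach() gained a model_channel_name keyword alongside this change).
    attached = sagemaker.estimator.Estimator.attach('my-previous-training-job',
                                                    model_channel_name='pretrained')

    # The reattached estimator exposes the job's artifacts, e.g. for deployment.
    predictor = attached.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')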

0 commit comments
