Skip to content

Commit 1711e33

Browse files
authored
Update latest version of PyTorch to 1.0 (aws#547)
1 parent 4b7db48 commit 1711e33

File tree

5 files changed

+17
-18
lines changed

5 files changed

+17
-18
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* bug-fix: Local Mode: Allow support for SSH in local mode
99
* bug-fix: Append retry id to default Airflow job name to avoid name collisions in retry
1010
* bug-fix: Local Mode: No longer requires s3 permissions to run local entry point file
11+
* feature: Estimators: add support for PyTorch 1.0.0
1112
* bug-fix: Local Mode: Move dependency on sagemaker_s3_output from rl.estimator to model
1213
* doc-fix: Fix quotes in estimator.py and model.py
1314

README.rst

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,10 @@ PyTorch SageMaker Estimators
406406
407407
With PyTorch SageMaker ``Estimators``, you can train and host PyTorch models on Amazon SageMaker.
408408
409-
Supported versions of PyTorch: ``0.4.0``, ``1.0.0.dev`` ("Preview").
409+
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``.
410410
411411
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
412412
413-
You can try the "Preview" version of PyTorch by specifying ``'1.0.0.dev'`` for ``framework_version`` when creating your PyTorch estimator.
414-
This will ensure you're using the latest version of ``torch-nightly``.
415-
416413
For more information about PyTorch, see https://github.com/pytorch/pytorch.
417414
418415
For more information about PyTorch SageMaker ``Estimators``, see `PyTorch SageMaker Estimators and Models`_.

src/sagemaker/pytorch/README.rst

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@ SageMaker PyTorch Estimators and Models
44

55
With PyTorch Estimators and Models, you can train and host PyTorch models on Amazon SageMaker.
66

7-
Supported versions of PyTorch: ``0.4.0``, ``1.0.0.dev`` ("Preview").
7+
Supported versions of PyTorch: ``0.4.0``, ``1.0.0``.
88

9-
You can try the "Preview" version of PyTorch by specifying ``1.0.0.dev`` for ``framework_version`` when creating your PyTorch estimator.
10-
This will ensure you're using the latest version of ``torch-nightly``.
9+
We recommend that you use the latest supported version, because that's where we focus most of our development efforts.
1110

1211
You can visit the PyTorch repository at https://github.com/pytorch/pytorch.
1312

@@ -49,7 +48,7 @@ You can then setup a ``PyTorch`` Estimator with keyword arguments to point to th
4948
role='SageMakerRole',
5049
train_instance_type='ml.p3.2xlarge',
5150
train_instance_count=1,
52-
framework_version='0.4.0')
51+
framework_version='1.0.0')
5352
5453
After that, you simply tell the estimator to start a training job and provide an S3 URL
5554
that is the path to your training data within Amazon S3:
@@ -137,7 +136,7 @@ directories ('train' and 'test').
137136
pytorch_estimator = PyTorch('pytorch-train.py',
138137
train_instance_type='ml.p3.2xlarge',
139138
train_instance_count=1,
140-
framework_version='0.4.0',
139+
framework_version='1.0.0',
141140
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
142141
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
143142
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -338,7 +337,7 @@ operation.
338337
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
339338
train_instance_type='ml.p3.2xlarge',
340339
train_instance_count=1,
341-
framework_version='0.4.0')
340+
framework_version='1.0.0')
342341
pytorch_estimator.fit('s3://my_bucket/my_training_data/')
343342
344343
# Deploy my estimator to a SageMaker Endpoint and get a Predictor
@@ -675,21 +674,21 @@ When training and deploying training scripts, SageMaker runs your Python script
675674
libraries installed. When creating the Estimator and calling deploy to create the SageMaker Endpoint, you can control
676675
the environment your script runs in.
677676

678-
SageMaker runs PyTorch Estimator scripts in either Python 2.7 or Python 3.5. You can select the Python version by
677+
SageMaker runs PyTorch Estimator scripts in either Python 2 or Python 3. You can select the Python version by
679678
passing a ``py_version`` keyword arg to the PyTorch Estimator constructor. Setting this to `py3` (the default) will cause your
680679
training script to be run on Python 3.5. Setting this to `py2` will cause your training script to be run on Python 2.7
681680
This Python version applies to both the Training Job, created by fit, and the Endpoint, created by deploy.
682681

683682
The PyTorch Docker images have the following dependencies installed:
684683

685684
+-----------------------------+---------------+-------------------+
686-
| Dependencies | pytorch 0.4.0 | pytorch 1.0.0.dev |
685+
| Dependencies | pytorch 0.4.0 | pytorch 1.0.0 |
687686
+-----------------------------+---------------+-------------------+
688687
| boto3 | >=1.7.35 | >=1.9.11 |
689688
+-----------------------------+---------------+-------------------+
690689
| botocore | >=1.10.35 | >=1.12.11 |
691690
+-----------------------------+---------------+-------------------+
692-
| CUDA (GPU image only) | 9.0 | 9.2 |
691+
| CUDA (GPU image only) | 9.0 | 9.0 |
693692
+-----------------------------+---------------+-------------------+
694693
| numpy | >=1.14.3 | >=1.15.2 |
695694
+-----------------------------+---------------+-------------------+
@@ -711,11 +710,11 @@ The PyTorch Docker images have the following dependencies installed:
711710
+-----------------------------+---------------+-------------------+
712711
| six | >=1.11.0 | >=1.11.0 |
713712
+-----------------------------+---------------+-------------------+
714-
| torch (torch-nightly) | 0.4.0 | 1.0.0.dev |
713+
| torch | 0.4.0 | 1.0.0 |
715714
+-----------------------------+---------------+-------------------+
716715
| torchvision | 0.2.1 | 0.2.1 |
717716
+-----------------------------+---------------+-------------------+
718-
| Python | 2.7 or 3.5 | 2.7 or 3.5 |
717+
| Python | 2.7 or 3.5 | 2.7 or 3.6 |
719718
+-----------------------------+---------------+-------------------+
720719

721720
The Docker images extend Ubuntu 16.04.

src/sagemaker/pytorch/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class PyTorch(Framework):
2929

3030
__framework_name__ = "pytorch"
3131

32+
LATEST_VERSION = '1.0'
33+
3234
def __init__(self, entry_point, source_dir=None, hyperparameters=None, py_version=PYTHON_VERSION,
3335
framework_version=None, image_name=None, **kwargs):
3436
"""

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.chainer import Chainer
2424
from sagemaker.local import LocalSession
2525
from sagemaker.mxnet import MXNet
26-
from sagemaker.pytorch.defaults import PYTORCH_VERSION
26+
from sagemaker.pytorch import PyTorch
2727
from sagemaker.rl import RLEstimator
2828
from sagemaker.sklearn.defaults import SKLEARN_VERSION
2929
from sagemaker.tensorflow.defaults import TF_VERSION
@@ -37,7 +37,7 @@ def pytest_addoption(parser):
3737
parser.addoption('--boto-config', action='store', default=None)
3838
parser.addoption('--chainer-full-version', action='store', default=Chainer.LATEST_VERSION)
3939
parser.addoption('--mxnet-full-version', action='store', default=MXNet.LATEST_VERSION)
40-
parser.addoption('--pytorch-full-version', action='store', default=PYTORCH_VERSION)
40+
parser.addoption('--pytorch-full-version', action='store', default=PyTorch.LATEST_VERSION)
4141
parser.addoption('--rl-coach-full-version', action='store',
4242
default=RLEstimator.COACH_LATEST_VERSION)
4343
parser.addoption('--rl-ray-full-version', action='store',
@@ -114,7 +114,7 @@ def ei_mxnet_version(request):
114114
return request.param
115115

116116

117-
@pytest.fixture(scope='module', params=['0.4', '0.4.0'])
117+
@pytest.fixture(scope='module', params=['0.4', '0.4.0', '1.0', '1.0.0'])
118118
def pytorch_version(request):
119119
return request.param
120120

0 commit comments

Comments
 (0)