Skip to content

Commit 7a82fb5

Browse files
authored
Update coach version for to 0.11.1 tensorflow. (aws#684)
1 parent 9f359c9 commit 7a82fb5

File tree

6 files changed

+27
-15
lines changed

6 files changed

+27
-15
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ CHANGELOG
1515
* bug-fix: pass accelerator_type in ``deploy`` for REST API TFS ``Model``
1616
* doc-fix: move content from tf/README.rst into sphynx project
1717
* doc-fix: Improve new developer experience in README
18+
* feature: Add support for Coach 0.11.1 for Tensorflow
1819

1920
1.18.3.post1
2021
============

src/sagemaker/rl/README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ SageMaker Reinforcement Learning Estimators
44

55
With Reinforcement Learning (RL) Estimators, you can train reinforcement learning models on Amazon SageMaker.
66

7-
Supported versions of Coach: ``0.10.1`` with TensorFlow, ``0.11.0`` with TensorFlow or MXNet.
7+
Supported versions of Coach: ``0.11.1``, ``0.10.1`` with TensorFlow, ``0.11.0`` with TensorFlow or MXNet.
88
For more information about Coach, see https://github.com/NervanaSystems/coach
99

1010
Supported versions of Ray: ``0.5.3`` with TensorFlow.
@@ -42,7 +42,7 @@ You can then create an ``RLEstimator`` with keyword arguments to point to this s
4242
4343
rl_estimator = RLEstimator(entry_point='coach-train.py',
4444
toolkit=RLToolkit.COACH,
45-
toolkit_version='0.11.0',
45+
toolkit_version='0.11.1',
4646
framework=RLFramework.TENSORFLOW,
4747
role='SageMakerRole',
4848
train_instance_type='ml.p3.2xlarge',

src/sagemaker/rl/estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@
4040
'tensorflow': '1.11',
4141
'mxnet': '1.3'
4242
},
43+
'0.11.1': {
44+
'tensorflow': '1.12',
45+
},
4346
'0.11': {
44-
'tensorflow': '1.11',
47+
'tensorflow': '1.12',
4548
'mxnet': '1.3'
4649
}
4750
},
@@ -69,7 +72,8 @@ class RLFramework(enum.Enum):
6972
class RLEstimator(Framework):
7073
"""Handle end-to-end training and deployment of custom RLEstimator code."""
7174

72-
COACH_LATEST_VERSION = '0.11.0'
75+
COACH_LATEST_VERSION_TF = '0.11.1'
76+
COACH_LATEST_VERSION_MXNET = '0.11.0'
7377
RAY_LATEST_VERSION = '0.5.3'
7478

7579
def __init__(self, entry_point, toolkit=None, toolkit_version=None, framework=None,

tests/conftest.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ def pytest_addoption(parser):
3939
parser.addoption('--mxnet-full-version', action='store', default=MXNet.LATEST_VERSION)
4040
parser.addoption('--ei-mxnet-full-version', action='store', default=MXNet.LATEST_VERSION)
4141
parser.addoption('--pytorch-full-version', action='store', default=PyTorch.LATEST_VERSION)
42-
parser.addoption('--rl-coach-full-version', action='store',
43-
default=RLEstimator.COACH_LATEST_VERSION)
42+
parser.addoption('--rl-coach-mxnet-full-version', action='store',
43+
default=RLEstimator.COACH_LATEST_VERSION_MXNET)
44+
parser.addoption('--rl-coach-tf-full-version', action='store',
45+
default=RLEstimator.COACH_LATEST_VERSION_TF)
4446
parser.addoption('--rl-ray-full-version', action='store',
4547
default=RLEstimator.RAY_LATEST_VERSION)
4648
parser.addoption('--sklearn-full-version', action='store', default=SKLEARN_VERSION)
@@ -128,7 +130,7 @@ def tf_version(request):
128130
return request.param
129131

130132

131-
@pytest.fixture(scope='module', params=['0.10.1', '0.10.1', '0.11', '0.11.0'])
133+
@pytest.fixture(scope='module', params=['0.10.1', '0.10.1', '0.11', '0.11.0', '0.11.1'])
132134
def rl_coach_tf_version(request):
133135
return request.param
134136

@@ -164,8 +166,13 @@ def pytorch_full_version(request):
164166

165167

166168
@pytest.fixture(scope='module')
167-
def rl_coach_full_version(request):
168-
return request.config.getoption('--rl-coach-full-version')
169+
def rl_coach_mxnet_full_version(request):
170+
return request.config.getoption('--rl-coach-mxnet-full-version')
171+
172+
173+
@pytest.fixture(scope='module')
174+
def rl_coach_tf_full_version(request):
175+
return request.config.getoption('--rl-coach-tf-full-version')
169176

170177

171178
@pytest.fixture(scope='module')

tests/integ/test_rl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727

2828
@pytest.mark.canary_quick
2929
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="RL images supports only Python 3.")
30-
def test_coach_mxnet(sagemaker_session, rl_coach_full_version):
31-
estimator = _test_coach(sagemaker_session, RLFramework.MXNET, rl_coach_full_version)
30+
def test_coach_mxnet(sagemaker_session, rl_coach_mxnet_full_version):
31+
estimator = _test_coach(sagemaker_session, RLFramework.MXNET, rl_coach_mxnet_full_version)
3232

3333
with timeout(minutes=15):
3434
estimator.fit(wait='False')
@@ -50,8 +50,8 @@ def test_coach_mxnet(sagemaker_session, rl_coach_full_version):
5050

5151

5252
@pytest.mark.skipif(PYTHON_VERSION != 'py3', reason="RL images supports only Python 3.")
53-
def test_coach_tf(sagemaker_session, rl_coach_full_version):
54-
estimator = _test_coach(sagemaker_session, RLFramework.TENSORFLOW, rl_coach_full_version)
53+
def test_coach_tf(sagemaker_session, rl_coach_tf_full_version):
54+
estimator = _test_coach(sagemaker_session, RLFramework.TENSORFLOW, rl_coach_tf_full_version)
5555

5656
with timeout(minutes=15):
5757
estimator.fit()

tests/unit/test_rl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _get_full_gpu_image_uri(toolkit, toolkit_version, framework):
7676

7777

7878
def _rl_estimator(sagemaker_session, toolkit=RLToolkit.COACH,
79-
toolkit_version=RLEstimator.COACH_LATEST_VERSION, framework=RLFramework.MXNET,
79+
toolkit_version=RLEstimator.COACH_LATEST_VERSION_MXNET, framework=RLFramework.MXNET,
8080
train_instance_type=None, base_job_name=None, **kwargs):
8181
return RLEstimator(entry_point=SCRIPT_PATH,
8282
toolkit=toolkit,
@@ -466,7 +466,7 @@ def test_wrong_framework_format(sagemaker_session):
466466
def test_wrong_toolkit_format(sagemaker_session):
467467
with pytest.raises(ValueError) as e:
468468
RLEstimator(toolkit='coach', framework=RLFramework.TENSORFLOW,
469-
toolkit_version=RLEstimator.COACH_LATEST_VERSION,
469+
toolkit_version=RLEstimator.COACH_LATEST_VERSION_TF,
470470
entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
471471
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
472472
framework_version=None)

0 commit comments

Comments
 (0)