Skip to content

Commit 67fa816

Browse files
longyuzhaoLongyu Zhaonavinsoni
authored
feature: modify RLEstimator to use newly generated Ray image (1.6.0) (#2717)
* Modify RLEstimator to use newly generated Ray image (1.6.0) Changes have been tested locally with Cartpole env. Rev-1: fix the image tag name and unit tests. * Fix integration tests due to new Ray version ('webui_host' has been deprecated) Co-authored-by: Longyu Zhao <[email protected]> Co-authored-by: Navin Soni <[email protected]>
1 parent 46b3cbf commit 67fa816

File tree

5 files changed

+53
-4
lines changed

5 files changed

+53
-4
lines changed

src/sagemaker/image_uri_config/ray-pytorch.json

+20
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,26 @@
2121
},
2222
"repository": "sagemaker-rl-ray-container",
2323
"tag_prefix": "ray-0.8.5-torch"
24+
},
25+
"1.6.0": {
26+
"py_versions": ["py36"],
27+
"registries": {
28+
"ap-northeast-1": "462105765813",
29+
"ap-northeast-2": "462105765813",
30+
"ap-south-1": "462105765813",
31+
"ap-southeast-1": "462105765813",
32+
"ap-southeast-2": "462105765813",
33+
"ca-central-1": "462105765813",
34+
"eu-central-1": "462105765813",
35+
"eu-west-1": "462105765813",
36+
"eu-west-2": "462105765813",
37+
"us-east-1": "462105765813",
38+
"us-east-2": "462105765813",
39+
"us-west-1": "462105765813",
40+
"us-west-2": "462105765813"
41+
},
42+
"repository": "sagemaker-rl-ray-container",
43+
"tag_prefix": "ray-1.6.0-torch"
2444
}
2545
}
2646
}

src/sagemaker/image_uri_config/ray-tensorflow.json

+20
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,26 @@
165165
},
166166
"repository": "sagemaker-rl-ray-container",
167167
"tag_prefix": "ray-0.8.5-tf"
168+
},
169+
"1.6.0": {
170+
"py_versions": ["py37"],
171+
"registries": {
172+
"ap-northeast-1": "462105765813",
173+
"ap-northeast-2": "462105765813",
174+
"ap-south-1": "462105765813",
175+
"ap-southeast-1": "462105765813",
176+
"ap-southeast-2": "462105765813",
177+
"ca-central-1": "462105765813",
178+
"eu-central-1": "462105765813",
179+
"eu-west-1": "462105765813",
180+
"eu-west-2": "462105765813",
181+
"us-east-1": "462105765813",
182+
"us-east-2": "462105765813",
183+
"us-west-1": "462105765813",
184+
"us-west-2": "462105765813"
185+
},
186+
"repository": "sagemaker-rl-ray-container",
187+
"tag_prefix": "ray-1.6.0-tf"
168188
}
169189
}
170190
}

src/sagemaker/rl/estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"0.6": {"tensorflow": "1.12"},
4646
"0.8.2": {"tensorflow": "2.1"},
4747
"0.8.5": {"tensorflow": "2.1", "pytorch": "1.5"},
48+
"1.6.0": {"tensorflow": "2.5.0", "pytorch": "1.8.1"},
4849
},
4950
}
5051

@@ -69,7 +70,7 @@ class RLEstimator(Framework):
6970

7071
COACH_LATEST_VERSION_TF = "0.11.1"
7172
COACH_LATEST_VERSION_MXNET = "0.11.0"
72-
RAY_LATEST_VERSION = "0.8.5"
73+
RAY_LATEST_VERSION = "1.6.0"
7374

7475
def __init__(
7576
self,

tests/data/ray_cartpole/train_ray.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from ray.tune.logger import pretty_print
66

77
# Based on https://github.com/ray-project/ray/blob/master/doc/source/rllib-training.rst#python-api
8-
ray.init(log_to_driver=False, webui_host="127.0.0.1")
8+
ray.init(log_to_driver=False)
99
config = ppo.DEFAULT_CONFIG.copy()
1010
config["num_gpus"] = int(os.environ.get("SM_NUM_GPUS", 0))
11-
checkpoint_dir = os.environ.get("SM_MODEL_DIR", "/Users/nadzeya/gym")
11+
checkpoint_dir = os.environ.get("SM_MODEL_DIR", "/tmp")
1212
config["num_workers"] = 1
1313
agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
1414

tests/unit/sagemaker/image_uris/test_rl.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,15 @@ def test_ray_tf(ray_tensorflow_version):
8686

8787

8888
def _expected_ray_tf_uri(ray_tf_version, processor):
89-
if Version(ray_tf_version) > Version("0.6.5"):
89+
if Version(ray_tf_version) > Version("1.0.0"):
90+
return expected_uris.framework_uri(
91+
"sagemaker-rl-ray-container",
92+
_version_for_tag("ray", ray_tf_version, "tf", True),
93+
RL_ACCOUNT,
94+
py_version="py37",
95+
processor=processor,
96+
)
97+
elif Version(ray_tf_version) > Version("0.6.5"):
9098
return expected_uris.framework_uri(
9199
"sagemaker-rl-ray-container",
92100
_version_for_tag("ray", ray_tf_version, "tf", True),

0 commit comments

Comments
 (0)