Skip to content

Commit 584b55f

Browse files
committed
modify train_ray.py for compatibility
1 parent af57234 commit 584b55f

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/sagemaker/rl/estimator.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,11 @@ def default_metric_definitions(cls, toolkit):
508508
return [
509509
{
510510
"Name": "episode_reward_mean",
511-
"Regex": "episode_reward_mean: (%s)" % float_regex,
511+
"Regex": "episode_reward_mean: {}".format(float_regex),
512+
},
513+
{
514+
"Name": "episode_reward_max",
515+
"Regex": "episode_reward_max: {}".format(float_regex),
512516
},
513-
{"Name": "episode_reward_max", "Regex": "episode_reward_max: (%s)" % float_regex, },
514517
]
515518
raise ValueError("An unknown RLToolkit enum was passed in. toolkit: {}".format(toolkit))

tests/data/ray_cartpole/train_ray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
config["num_gpus"] = int(os.environ.get("SM_NUM_GPUS", 0))
1111
checkpoint_dir = os.environ.get("SM_MODEL_DIR", "/Users/nadzeya/gym")
1212
config["num_workers"] = 1
13-
agent = ppo.PPOAgent(config=config, env="CartPole-v0")
13+
agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
1414

1515
# Can optionally call agent.restore(path) to load a checkpoint.
1616

0 commit comments

Comments
 (0)