Skip to content

Commit c8e0494

Browse files
authored
Fix hyperparameter name for detecting a tuning job (#38)
1 parent f0ac06e commit c8e0494

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/tf_container/train_entry_point.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _get_checkpoint_dir(env):
111111
checkpoint_path = env.hyperparameters['checkpoint_path']
112112

113113
# If this is not part of a tuning job, then we can just use the specified checkpoint path
114-
if 'algorithms_tuning_objective_metric' not in env.hyperparameters:
114+
if '_tuning_objective_metric' not in env.hyperparameters:
115115
return checkpoint_path
116116

117117
job_name = env.job_name

test/unit/test_train_entry_point.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
CHECKPOINT_PATH = 'customer/checkpoint/path'
2222
JOB_NAME = 'test-1234'
23+
TUNING_HYPERPARAMETER_NAME = '_tuning_objective_metric'
2324
TUNING_METRIC = 'some-metric'
2425

2526

@@ -70,7 +71,7 @@ def test_get_checkpoint_dir_with_job_name_in_path(train_entry_point_module):
7071
checkpoint_path_with_job_name = '{}/checkpoints'.format(JOB_NAME)
7172
hyperparameters = {
7273
'checkpoint_path': checkpoint_path_with_job_name,
73-
'algorithms_tuning_objective_metric': TUNING_METRIC,
74+
TUNING_HYPERPARAMETER_NAME: TUNING_METRIC,
7475
}
7576
env = Mock(name='env', hyperparameters=hyperparameters, job_name=JOB_NAME)
7677
checkpoint_dir = train_entry_point_module._get_checkpoint_dir(env)
@@ -81,7 +82,7 @@ def test_get_checkpoint_dir_with_job_name_in_path(train_entry_point_module):
8182
def test_get_checkpoint_dir_without_job_name_env(train_entry_point_module):
8283
hyperparameters = {
8384
'checkpoint_path': CHECKPOINT_PATH,
84-
'algorithms_tuning_objective_metric': TUNING_METRIC,
85+
TUNING_HYPERPARAMETER_NAME: TUNING_METRIC,
8586
}
8687
env = Mock(name='env', hyperparameters=hyperparameters, job_name=None)
8788
checkpoint_dir = train_entry_point_module._get_checkpoint_dir(env)
@@ -92,7 +93,7 @@ def test_get_checkpoint_dir_without_job_name_env(train_entry_point_module):
9293
def test_get_checkpoint_dir_appending_job_name(train_entry_point_module):
9394
hyperparameters = {
9495
'checkpoint_path': CHECKPOINT_PATH,
95-
'algorithms_tuning_objective_metric': TUNING_METRIC,
96+
TUNING_HYPERPARAMETER_NAME: TUNING_METRIC,
9697
}
9798
env = Mock(name='env', hyperparameters=hyperparameters, job_name=JOB_NAME)
9899
checkpoint_dir = train_entry_point_module._get_checkpoint_dir(env)

0 commit comments

Comments
 (0)