Commit 903a5f2

Fix hyperparameter strategy docs (aws#5045)
1 parent 5682c42 commit 903a5f2

1 file changed

src/sagemaker/tuner.py

Lines changed: 20 additions & 26 deletions
@@ -18,21 +18,20 @@
 import inspect
 import json
 import logging
-
 from enum import Enum
-from typing import Union, Dict, Optional, List, Set
+from typing import Dict, List, Optional, Set, Union
 
 import sagemaker
 from sagemaker.amazon.amazon_estimator import (
-    RecordSet,
     AmazonAlgorithmEstimatorBase,
     FileSystemRecordSet,
+    RecordSet,
 )
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.analytics import HyperparameterTuningJobAnalytics
 from sagemaker.deprecations import removed_function
-from sagemaker.estimator import Framework, EstimatorBase
-from sagemaker.inputs import TrainingInput, FileSystemInput
+from sagemaker.estimator import EstimatorBase, Framework
+from sagemaker.inputs import FileSystemInput, TrainingInput
 from sagemaker.job import _Job
 from sagemaker.jumpstart.utils import (
     add_jumpstart_uri_tags,
@@ -44,18 +43,17 @@
     IntegerParameter,
     ParameterRange,
 )
-from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
-
 from sagemaker.session import Session
 from sagemaker.utils import (
+    Tags,
     base_from_name,
     base_name_from_image,
+    format_tags,
     name_from_base,
     to_string,
-    format_tags,
-    Tags,
 )
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 
 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
@@ -133,15 +131,12 @@ def __init__(
 
         if warm_start_type not in list(WarmStartTypes):
             raise ValueError(
-                "Invalid type: {}, valid warm start types are: {}".format(
-                    warm_start_type, list(WarmStartTypes)
-                )
+                f"Invalid type: {warm_start_type}, "
+                f"valid warm start types are: {list(WarmStartTypes)}"
             )
 
         if not parents:
-            raise ValueError(
-                "Invalid parents: {}, parents should not be None/empty".format(parents)
-            )
+            raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty")
 
         self.type = warm_start_type
         self.parents = set(parents)
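For reference, the checks above run when a warm-start configuration is built. A minimal sketch of constructing one (the parent tuning job name is a placeholder, not taken from this commit):

    from sagemaker.tuner import WarmStartConfig, WarmStartTypes

    # Both arguments are validated in __init__: warm_start_type must be a
    # WarmStartTypes member, and parents must be a non-empty set of job names.
    warm_start_config = WarmStartConfig(
        warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
        parents={"previous-tuning-job-name"},  # placeholder parent job name
    )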
@@ -1455,9 +1450,7 @@ def _get_best_training_job(self):
             return tuning_job_describe_result["BestTrainingJob"]
         except KeyError:
             raise Exception(
-                "Best training job not available for tuning job: {}".format(
-                    self.latest_tuning_job.name
-                )
+                f"Best training job not available for tuning job: {self.latest_tuning_job.name}"
             )
 
     def _ensure_last_tuning_job(self):
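The message above reaches callers through the public best_training_job() accessor. A hedged usage sketch, assuming `tuner` is a HyperparameterTuner whose fit() call has already completed (the variable is not defined in this commit):

    # `tuner` is assumed to be an already-fitted HyperparameterTuner.
    try:
        best_job_name = tuner.best_training_job()
        print(f"Best training job: {best_job_name}")
    except Exception as err:
        # Raised when the tuning job has not yet produced a best training job.
        print(f"No best training job available: {err}")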
@@ -1920,8 +1913,11 @@ def create(
                 :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches.
                 If not specified, a default job name is generated,
                 based on the training image name and current timestamp.
-            strategy (str): Strategy to be used for hyperparameter estimations
-                (default: 'Bayesian').
+            strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations.
+                More information about different strategies:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html.
+                Available options are: 'Bayesian', 'Random', 'Hyperband',
+                'Grid' (default: 'Bayesian')
             strategy_config (dict): The configuration for a training job launched by a
                 hyperparameter tuning job.
             completion_criteria_config (dict): The configuration for tuning job completion criteria.
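To illustrate the options now documented, a minimal sketch of requesting a non-default strategy (the estimator, metric name, and ranges below are placeholders, not part of this commit):

    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

    # `estimator` is assumed to be an already-configured SageMaker estimator.
    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="validation:accuracy",  # placeholder metric name
        hyperparameter_ranges={"learning_rate": ContinuousParameter(0.001, 0.1)},
        strategy="Random",  # one of 'Bayesian' (default), 'Random', 'Hyperband', 'Grid'
        max_jobs=10,
        max_parallel_jobs=2,
    )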
@@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa
             return
 
         if not isinstance(value, dict):
-            raise ValueError(
-                "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys)
-            )
+            raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys")
 
         value_keys = sorted(value.keys())
 
         if require_same_keys:
             if value_keys != allowed_keys:
                 raise ValueError(
-                    "The keys of argument '{}' must be the same as {}".format(name, allowed_keys)
+                    f"The keys of argument '{name}' must be the same as {allowed_keys}"
                 )
         else:
             if not set(value_keys).issubset(set(allowed_keys)):
                 raise ValueError(
-                    "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys)
+                    f"The keys of argument '{name}' must be a subset of {allowed_keys}"
                 )
 
     def _add_estimator(
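As a quick standalone illustration of the subset check above (a sketch mirroring the helper's behavior, not the library code itself; the argument names are made up):

    # Mirrors the require_same_keys=False branch: unknown keys are rejected.
    allowed_keys = ["estimator-1", "estimator-2"]
    value = {"estimator-1": {}, "unexpected": {}}

    if not set(value.keys()).issubset(set(allowed_keys)):
        raise ValueError(f"The keys of argument 'example' must be a subset of {allowed_keys}")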
