|
18 | 18 | import inspect
|
19 | 19 | import json
|
20 | 20 | import logging
|
21 |
| - |
22 | 21 | from enum import Enum
|
23 |
| -from typing import Union, Dict, Optional, List, Set |
| 22 | +from typing import Dict, List, Optional, Set, Union |
24 | 23 |
|
25 | 24 | import sagemaker
|
26 | 25 | from sagemaker.amazon.amazon_estimator import (
|
27 |
| - RecordSet, |
28 | 26 | AmazonAlgorithmEstimatorBase,
|
29 | 27 | FileSystemRecordSet,
|
| 28 | + RecordSet, |
30 | 29 | )
|
31 | 30 | from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
|
32 | 31 | from sagemaker.analytics import HyperparameterTuningJobAnalytics
|
33 | 32 | from sagemaker.deprecations import removed_function
|
34 |
| -from sagemaker.estimator import Framework, EstimatorBase |
35 |
| -from sagemaker.inputs import TrainingInput, FileSystemInput |
| 33 | +from sagemaker.estimator import EstimatorBase, Framework |
| 34 | +from sagemaker.inputs import FileSystemInput, TrainingInput |
36 | 35 | from sagemaker.job import _Job
|
37 | 36 | from sagemaker.jumpstart.utils import (
|
38 | 37 | add_jumpstart_uri_tags,
|
|
44 | 43 | IntegerParameter,
|
45 | 44 | ParameterRange,
|
46 | 45 | )
|
47 |
| -from sagemaker.workflow.entities import PipelineVariable |
48 |
| -from sagemaker.workflow.pipeline_context import runnable_by_pipeline |
49 |
| - |
50 | 46 | from sagemaker.session import Session
|
51 | 47 | from sagemaker.utils import (
|
| 48 | + Tags, |
52 | 49 | base_from_name,
|
53 | 50 | base_name_from_image,
|
| 51 | + format_tags, |
54 | 52 | name_from_base,
|
55 | 53 | to_string,
|
56 |
| - format_tags, |
57 |
| - Tags, |
58 | 54 | )
|
| 55 | +from sagemaker.workflow.entities import PipelineVariable |
| 56 | +from sagemaker.workflow.pipeline_context import runnable_by_pipeline |
59 | 57 |
|
60 | 58 | AMAZON_ESTIMATOR_MODULE = "sagemaker"
|
61 | 59 | AMAZON_ESTIMATOR_CLS_NAMES = {
|
@@ -133,15 +131,12 @@ def __init__(
|
133 | 131 |
|
134 | 132 | if warm_start_type not in list(WarmStartTypes):
|
135 | 133 | raise ValueError(
|
136 |
| - "Invalid type: {}, valid warm start types are: {}".format( |
137 |
| - warm_start_type, list(WarmStartTypes) |
138 |
| - ) |
| 134 | + f"Invalid type: {warm_start_type}, " |
| 135 | + f"valid warm start types are: {list(WarmStartTypes)}" |
139 | 136 | )
|
140 | 137 |
|
141 | 138 | if not parents:
|
142 |
| - raise ValueError( |
143 |
| - "Invalid parents: {}, parents should not be None/empty".format(parents) |
144 |
| - ) |
| 139 | + raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty") |
145 | 140 |
|
146 | 141 | self.type = warm_start_type
|
147 | 142 | self.parents = set(parents)
|
@@ -1455,9 +1450,7 @@ def _get_best_training_job(self):
|
1455 | 1450 | return tuning_job_describe_result["BestTrainingJob"]
|
1456 | 1451 | except KeyError:
|
1457 | 1452 | raise Exception(
|
1458 |
| - "Best training job not available for tuning job: {}".format( |
1459 |
| - self.latest_tuning_job.name |
1460 |
| - ) |
| 1453 | + f"Best training job not available for tuning job: {self.latest_tuning_job.name}" |
1461 | 1454 | )
|
1462 | 1455 |
|
1463 | 1456 | def _ensure_last_tuning_job(self):
|
@@ -1920,8 +1913,11 @@ def create(
|
1920 | 1913 | :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches.
|
1921 | 1914 | If not specified, a default job name is generated,
|
1922 | 1915 | based on the training image name and current timestamp.
|
1923 |
| - strategy (str): Strategy to be used for hyperparameter estimations |
1924 |
| - (default: 'Bayesian'). |
| 1916 | + strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations. |
| 1917 | + More information about different strategies: |
| 1918 | + https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html. |
| 1919 | + Available options are: 'Bayesian', 'Random', 'Hyperband', |
| 1920 | + 'Grid' (default: 'Bayesian') |
1925 | 1921 | strategy_config (dict): The configuration for a training job launched by a
|
1926 | 1922 | hyperparameter tuning job.
|
1927 | 1923 | completion_criteria_config (dict): The configuration for tuning job completion criteria.
|
@@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=Fa
|
2080 | 2076 | return
|
2081 | 2077 |
|
2082 | 2078 | if not isinstance(value, dict):
|
2083 |
| - raise ValueError( |
2084 |
| - "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys) |
2085 |
| - ) |
| 2079 | + raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys") |
2086 | 2080 |
|
2087 | 2081 | value_keys = sorted(value.keys())
|
2088 | 2082 |
|
2089 | 2083 | if require_same_keys:
|
2090 | 2084 | if value_keys != allowed_keys:
|
2091 | 2085 | raise ValueError(
|
2092 |
| - "The keys of argument '{}' must be the same as {}".format(name, allowed_keys) |
| 2086 | + f"The keys of argument '{name}' must be the same as {allowed_keys}" |
2093 | 2087 | )
|
2094 | 2088 | else:
|
2095 | 2089 | if not set(value_keys).issubset(set(allowed_keys)):
|
2096 | 2090 | raise ValueError(
|
2097 |
| - "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys) |
| 2091 | + f"The keys of argument '{name}' must be a subset of {allowed_keys}" |
2098 | 2092 | )
|
2099 | 2093 |
|
2100 | 2094 | def _add_estimator(
|
|
0 commit comments