Improvement of the tuner documentation #4506

Merged 1 commit on Mar 15, 2024
21 changes: 13 additions & 8 deletions src/sagemaker/tuner.py
@@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""

from __future__ import absolute_import

import importlib
@@ -641,8 +642,11 @@ def __init__(
extract the metric from the logs. This should be defined only
for hyperparameter tuning jobs that don't use an Amazon
algorithm.
strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations
(default: 'Bayesian').
strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations.
More information about different strategies:
https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html.
Available options are: 'Bayesian', 'Random', 'Hyperband',
'Grid' (default: 'Bayesian')
objective_type (str or PipelineVariable): The type of the objective metric for
evaluating training jobs. This value can be either 'Minimize' or
'Maximize' (default: 'Maximize').
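For reference, a minimal sketch of how the strategy documented here is selected when constructing a tuner; the estimator, metric name, and hyperparameter ranges below are placeholders rather than part of this change:

    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

    # `my_estimator` stands in for any already configured SageMaker Estimator.
    tuner = HyperparameterTuner(
        estimator=my_estimator,
        objective_metric_name="validation:accuracy",  # placeholder metric name
        hyperparameter_ranges={"learning_rate": ContinuousParameter(1e-4, 1e-1)},
        metric_definitions=[
            {"Name": "validation:accuracy", "Regex": "accuracy=([0-9\\.]+)"}
        ],
        strategy="Random",          # one of 'Bayesian', 'Random', 'Hyperband', 'Grid'
        objective_type="Maximize",
        max_jobs=10,
        max_parallel_jobs=2,
    )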
@@ -759,7 +763,8 @@ def __init__(
self.autotune = autotune

def override_resource_config(
self, instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]]
self,
instance_configs: Union[List[InstanceConfig], Dict[str, List[InstanceConfig]]],
):
"""Override the instance configuration of the estimators used by the tuner.

@@ -966,7 +971,7 @@ def fit(
include_cls_metadata: Union[bool, Dict[str, bool]] = False,
estimator_kwargs: Optional[Dict[str, dict]] = None,
wait: bool = True,
**kwargs
**kwargs,
):
"""Start a hyperparameter tuning job.

@@ -1055,7 +1060,7 @@ def _fit_with_estimator_dict(self, inputs, job_name, include_cls_metadata, estim
allowed_keys=estimator_names,
)

for (estimator_name, estimator) in self.estimator_dict.items():
for estimator_name, estimator in self.estimator_dict.items():
ins = inputs.get(estimator_name, None) if inputs is not None else None
args = estimator_kwargs.get(estimator_name, {}) if estimator_kwargs is not None else {}
self._prepare_estimator_for_tuning(estimator, ins, job_name, **args)
@@ -1282,7 +1287,7 @@ def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, jo
objective_metric_name_dict=objective_metric_name_dict,
hyperparameter_ranges_dict=hyperparameter_ranges_dict,
metric_definitions_dict=metric_definitions_dict,
**init_params
**init_params,
)
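These *_dict keyword arguments mirror the multi-estimator construction path; a hedged sketch using HyperparameterTuner.create, with placeholder estimators and metric names (ContinuousParameter as imported earlier):

    # `estimator_a` and `estimator_b` stand in for configured Estimators.
    multi_tuner = HyperparameterTuner.create(
        estimator_dict={"estimator-a": estimator_a, "estimator-b": estimator_b},
        objective_metric_name_dict={
            "estimator-a": "validation:accuracy",
            "estimator-b": "validation:auc",
        },
        hyperparameter_ranges_dict={
            "estimator-a": {"learning_rate": ContinuousParameter(1e-4, 1e-1)},
            "estimator-b": {"learning_rate": ContinuousParameter(1e-3, 1e-1)},
        },
    )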

@@ -1297,7 +1302,7 @@ def deploy(
model_name=None,
kms_key=None,
data_capture_config=None,
**kwargs
**kwargs,
):
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint.

@@ -1363,7 +1368,7 @@ def deploy(
model_name=model_name,
kms_key=kms_key,
data_capture_config=data_capture_config,
**kwargs
**kwargs,
)

def stop_tuning_job(self):
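Finally, an illustrative sketch of deploying the best model found by a completed tuning job; the endpoint settings are placeholders:

    # Deploys the model from the best training job to a real-time endpoint.
    predictor = tuner.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
    )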