Skip to content

feature: support for describing hyperparameter tuning job #1594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,20 @@ def create_tuning_job(
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)

def describe_tuning_job(self, job_name):
"""Calls the DescribeHyperParameterTuningJob API for the given job name
and returns the response.

Args:
job_name (str): The name of the hyperparameter tuning job to describe.

Returns:
dict: A dictionary response with the hyperparameter tuning job description.
"""
return self.sagemaker_client.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=job_name
)

@classmethod
def _map_tuning_config(
cls,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,10 @@ def stop_tuning_job(self):
self._ensure_last_tuning_job()
self.latest_tuning_job.stop()

def describe(self):
"""Returns a response from the DescribrHyperParameterTuningJob API call."""
return self.sagemaker_session.describe_tuning_job(self._current_job_name)

def wait(self):
"""Wait for latest hyperparameter tuning job to finish."""
self._ensure_last_tuning_job()
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2125,3 +2125,11 @@ def test_list_candidates_for_auto_ml_job_with_optional_args(sagemaker_session):
sagemaker_session.sagemaker_client.list_candidates_for_auto_ml_job.assert_called_with(
**COMPLETE_EXPECTED_LIST_CANDIDATES_ARGS
)


def test_describe_tuning_Job(sagemaker_session):
job_name = "hyper-parameter-tuning"
sagemaker_session.describe_tuning_job(job_name=job_name)
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_with(
HyperParameterTuningJobName=job_name
)
5 changes: 5 additions & 0 deletions tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,11 @@ def test_create_warm_start_tuner_with_single_estimator_dict(
assert tuner.warm_start_config.parents == additional_parents


def test_describe(tuner):
tuner.describe()
tuner.sagemaker_session.describe_tuning_job.assert_called_once()


def _convert_tuning_job_details(job_details, estimator_name):
"""Convert a tuning job description using the 'TrainingJobDefinition' field into a new one using a single-item
'TrainingJobDefinitions' field (list).
Expand Down