Skip to content

Commit 7ca5af4

Browse files
committed
add tuning step support
1 parent 0716e9f commit 7ca5af4

File tree

8 files changed

+658
-23
lines changed

8 files changed

+658
-23
lines changed

src/sagemaker/session.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2027,6 +2027,45 @@ def create_tuning_job(
20272027
"Only one of training_config and training_config_list should be provided."
20282028
)
20292029

2030+
tune_request = self._get_tuning_request(
2031+
job_name=job_name,
2032+
tuning_config=tuning_config,
2033+
training_config=training_config,
2034+
training_config_list=training_config_list,
2035+
warm_start_config=warm_start_config,
2036+
tags=tags,
2037+
)
2038+
2039+
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2040+
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
2041+
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
2042+
2043+
def _get_tuning_request(
2044+
self,
2045+
job_name,
2046+
tuning_config,
2047+
training_config=None,
2048+
training_config_list=None,
2049+
warm_start_config=None,
2050+
tags=None,
2051+
):
2052+
"""Construct CreateHyperParameterTuningJob request
2053+
2054+
Args:
2055+
job_name (str): Name of the tuning job being created.
2056+
tuning_config (dict): Configuration to launch the tuning job.
2057+
training_config (dict): Configuration to launch training jobs under the tuning job
2058+
using a single algorithm.
2059+
training_config_list (list[dict]): A list of configurations to launch training jobs
2060+
under the tuning job using one or multiple algorithms. Either training_config
2061+
or training_config_list should be provided, but not both.
2062+
warm_start_config (dict): Configuration defining the type of warm start and
2063+
other required configurations.
2064+
tags (list[dict]): List of tags for labeling the tuning job. For more, see
2065+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2066+
Returns:
2067+
dict: A dictionary for CreateHyperParameterTuningJob request
2068+
"""
20302069
tune_request = {
20312070
"HyperParameterTuningJobName": job_name,
20322071
"HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config),
@@ -2047,9 +2086,7 @@ def create_tuning_job(
20472086
if tags is not None:
20482087
tune_request["Tags"] = tags
20492088

2050-
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
2051-
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
2052-
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
2089+
return tune_request
20532090

20542091
def describe_tuning_job(self, job_name):
20552092
"""Calls DescribeHyperParameterTuningJob API for the given job name, returns the response.

src/sagemaker/tuner.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False)
346346
estimator_name: self._prepare_static_hyperparameters(
347347
estimator,
348348
self._hyperparameter_ranges_dict[estimator_name],
349-
include_cls_metadata.get(estimator_name, False),
349+
include_cls_metadata.get(estimator_name, False)
350+
if isinstance(include_cls_metadata, dict)
351+
else include_cls_metadata,
350352
)
351353
for (estimator_name, estimator) in self.estimator_dict.items()
352354
}
@@ -1460,6 +1462,22 @@ def start_new(cls, tuner, inputs):
14601462
sagemaker.tuner._TuningJob: Constructed object that captures all
14611463
information about the started job.
14621464
"""
1465+
tuner_args = cls._get_tuner_args(tuner, inputs)
1466+
tuner.sagemaker_session.create_tuning_job(**tuner_args)
1467+
1468+
return cls(tuner.sagemaker_session, tuner._current_job_name)
1469+
1470+
@classmethod
1471+
def _get_tuner_args(cls, tuner, inputs):
1472+
"""Gets a dict of arguments for a new Amazon SageMaker tuning job from the tuner
1473+
Args:
1474+
tuner (:class:`~sagemaker.tuner.HyperparameterTuner`):
1475+
The ``HyperparameterTuner`` instance that started the job.
1476+
inputs: Information about the training data. Please refer to the
1477+
``fit()`` method of the associated estimator.
1478+
Returns:
1479+
Dict: dict of arguments for the `sagemaker.session.Session.create_tuning_job` method
1480+
"""
14631481
warm_start_config_req = None
14641482
if tuner.warm_start_config:
14651483
warm_start_config_req = tuner.warm_start_config.to_input_req()
@@ -1506,8 +1524,7 @@ def start_new(cls, tuner, inputs):
15061524
for estimator_name in sorted(tuner.estimator_dict.keys())
15071525
]
15081526

1509-
tuner.sagemaker_session.create_tuning_job(**tuner_args)
1510-
return cls(tuner.sagemaker_session, tuner._current_job_name)
1527+
return tuner_args
15111528

15121529
@staticmethod
15131530
def _prepare_training_config(

src/sagemaker/workflow/properties.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""The properties definitions for workflow."""
1414
from __future__ import absolute_import
1515

16-
from typing import Dict, Union
16+
from typing import Dict, Union, List
1717

1818
import attr
1919

@@ -40,27 +40,35 @@ def __new__(mcs, *args, **kwargs):
4040
class Properties(metaclass=PropertiesMeta):
4141
"""Properties for use in workflow expressions."""
4242

43-
def __init__(self, path: str, shape_name: str = None):
43+
def __init__(
44+
self,
45+
path: str,
46+
shape_name: str = None,
47+
shape_names: List[str] = None,
48+
):
4449
"""Create a Properties instance representing the given shape.
4550
4651
Args:
4752
path (str): The parent path of the Properties instance.
4853
shape_name (str): The botocore sagemaker service model shape name.
54+
shape_names (List[str]): A list of botocore sagemaker service model shape names.
4955
"""
5056
self._path = path
51-
self._shape_name = shape_name
52-
53-
shape = Properties._shapes.get(self._shape_name, {})
54-
shape_type = shape.get("type")
55-
if shape_type in Properties._primitive_types:
56-
self.__str__ = shape_name
57-
elif shape_type == "structure":
58-
members = shape["members"]
59-
for key, info in members.items():
60-
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
61-
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
62-
else:
63-
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
57+
shape_names = [] if shape_names is None else shape_names
58+
self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
59+
60+
for name in self._shape_names:
61+
shape = Properties._shapes.get(name, {})
62+
shape_type = shape.get("type")
63+
if shape_type in Properties._primitive_types:
64+
self.__str__ = name
65+
elif shape_type == "structure":
66+
members = shape["members"]
67+
for key, info in members.items():
68+
if Properties._shapes.get(info["shape"], {}).get("type") == "list":
69+
self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
70+
else:
71+
self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
6472

6573
@property
6674
def expr(self):
@@ -77,8 +85,10 @@ def __init__(self, path: str, shape_name: str = None):
7785
Args:
7886
path (str): The parent path of the PropertiesList instance.
7987
shape_name (str): The botocore sagemaker service model shape name.
88+
root_shape_name (str): The botocore sagemaker service model shape name. (NOTE(review): no ``root_shape_name`` parameter exists in this signature — likely a copy-paste leftover; confirm and remove or rename.)
8089
"""
8190
super(PropertiesList, self).__init__(path, shape_name)
91+
self.shape_name = shape_name
8292
self._items: Dict[Union[int, str], Properties] = dict()
8393

8494
def __getitem__(self, item: Union[int, str]):
@@ -88,7 +98,7 @@ def __getitem__(self, item: Union[int, str]):
8898
item (Union[int, str]): The index of the item in sequence.
8999
"""
90100
if item not in self._items.keys():
91-
shape = Properties._shapes.get(self._shape_name)
101+
shape = Properties._shapes.get(self.shape_name)
92102
member = shape["member"]["shape"]
93103
if isinstance(item, str):
94104
property_item = Properties(f"{self._path}['{item}']", member)

src/sagemaker/workflow/steps.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Processor,
3535
)
3636
from sagemaker.transformer import Transformer, _TransformJob
37+
from sagemaker.tuner import HyperparameterTuner, _TuningJob
3738
from sagemaker.workflow.entities import (
3839
DefaultEnumMeta,
3940
Entity,
@@ -55,6 +56,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5556
TRAINING = "Training"
5657
TRANSFORM = "Transform"
5758
CALLBACK = "Callback"
59+
TUNING = "Tuning"
5860

5961

6062
@attr.s
@@ -96,6 +98,7 @@ def add_depends_on(self, step_names: List[str]):
9698
"""Add step names to the current step depends on list"""
9799
if not step_names:
98100
return
101+
99102
if not self.depends_on:
100103
self.depends_on = []
101104
self.depends_on.extend(step_names)
@@ -417,3 +420,106 @@ def to_request(self) -> RequestType:
417420
property_file.expr for property_file in self.property_files
418421
]
419422
return request_dict
423+
424+
425+
class TuningStep(Step):
426+
"""Tuning step for workflow."""
427+
428+
def __init__(
429+
self,
430+
name: str,
431+
tuner: HyperparameterTuner,
432+
inputs=None,
433+
job_arguments: List[str] = None,
434+
cache_config: CacheConfig = None,
435+
depends_on: List[str] = None,
436+
):
437+
"""Construct a TuningStep, given a `HyperparameterTuner` instance.
438+
439+
In addition to the tuner instance, the other arguments are those that are supplied to
440+
the `fit` method of the `sagemaker.tuner.HyperparameterTuner`.
441+
442+
Args:
443+
name (str): The name of the tuning step.
444+
tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
445+
inputs: Information about the training data. Please refer to the
446+
``fit()`` method of the associated estimator, as this can take
447+
any of the following forms:
448+
449+
* (str) - The S3 location where training data is saved.
450+
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
451+
If using multiple channels for training data, you can specify
452+
a dict mapping channel names to strings or
453+
:func:`~sagemaker.inputs.TrainingInput` objects.
454+
* (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
455+
that can provide additional information about the training dataset.
456+
See :func:`sagemaker.inputs.TrainingInput` for full details.
457+
* (sagemaker.session.FileSystemInput) - channel configuration for
458+
a file system data source that can provide additional information as well as
459+
the path to the training dataset.
460+
* (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
461+
Amazon :class:~`Record` objects serialized and stored in S3.
462+
For use with an estimator for an Amazon algorithm.
463+
* (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) -
464+
Amazon SageMaker channel configuration for a file system data source for
465+
Amazon algorithms.
466+
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
467+
:class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
468+
where each instance is a different channel of training data.
469+
* (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of
470+
:class:~`sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects,
471+
where each instance is a different channel of training data.
472+
job_arguments (List[str]): A list of strings to be passed into the tuning job.
472+
Defaults to `None`.
474+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
475+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TuningStep`
476+
depends on
477+
"""
478+
super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
479+
self.tuner = tuner
480+
self.inputs = inputs
481+
self.job_arguments = job_arguments
482+
self._properties = Properties(
483+
path=f"Steps.{name}",
484+
shape_names=[
485+
"DescribeHyperParameterTuningJobResponse",
486+
"ListTrainingJobsForHyperParameterTuningJobResponse",
487+
],
488+
)
489+
self.cache_config = cache_config
490+
491+
@property
492+
def arguments(self) -> RequestType:
493+
"""The arguments dict that is used to call `create_hyper_parameter_tuning_job`.
494+
495+
NOTE: The CreateHyperParameterTuningJob request is not quite the
496+
args list that workflow needs.
497+
The HyperParameterTuningJobName attribute cannot be included.
498+
"""
499+
if self.tuner.estimator is not None:
500+
self.tuner.estimator._prepare_for_training()
501+
else:
502+
for _, estimator in self.tuner.estimator_dict.items():
503+
estimator._prepare_for_training()
504+
505+
self.tuner._prepare_for_tuning()
506+
tuner_args = _TuningJob._get_tuner_args(self.tuner, self.inputs)
507+
request_dict = self.tuner.sagemaker_session._get_tuning_request(**tuner_args)
508+
request_dict.pop("HyperParameterTuningJobName")
509+
510+
return request_dict
511+
512+
@property
513+
def properties(self):
514+
"""A Properties object representing the `DescribeHyperParameterTuningJobResponse` and
515+
`ListTrainingJobsForHyperParameterTuningJobResponse` data model.
516+
"""
517+
return self._properties
518+
519+
def to_request(self) -> RequestType:
520+
"""Updates the dictionary with cache configuration."""
521+
request_dict = super().to_request()
522+
if self.cache_config:
523+
request_dict.update(self.cache_config.config)
524+
525+
return request_dict

tests/data/pytorch_mnist/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def train(args):
182182
accuracy = test(model, test_loader, device)
183183
save_model(model, args.model_dir)
184184

185-
logger.debug("Overall test accuracy: {}".format(accuracy))
185+
logger.debug("Overall test accuracy: {};".format(accuracy))
186186

187187

188188
def test(model, test_loader, device):

0 commit comments

Comments
 (0)