
Commit efdf3ec

Authored by jerrypeng7773, ahsan-z-khan, and apogupta2018
add support for SageMaker workflow tuning step (#2497)
* add helper function to generate no-op (data ingestion only) recipe
* separate flow generation by source input type + move generation helpers to sagemaker.wrangler.ingestion
* create an internal helper function to generate output node
* add ingestion test using dw processor via pipeline execution
* verify the fg query df
* fix tests
* add tuning step support
* fix docstyle check
* add helper function to get tuning step top performing model s3 uri

Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: apogupta2018 <[email protected]>
1 parent e7ef9e3 commit efdf3ec

File tree

8 files changed (+735, -23 lines)


src/sagemaker/session.py (+40, -3)

@@ -2033,6 +2033,45 @@ def create_tuning_job(
                 "Only one of training_config and training_config_list should be provided."
             )
 
+        tune_request = self._get_tuning_request(
+            job_name=job_name,
+            tuning_config=tuning_config,
+            training_config=training_config,
+            training_config_list=training_config_list,
+            warm_start_config=warm_start_config,
+            tags=tags,
+        )
+
+        LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
+        LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
+        self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
+
+    def _get_tuning_request(
+        self,
+        job_name,
+        tuning_config,
+        training_config=None,
+        training_config_list=None,
+        warm_start_config=None,
+        tags=None,
+    ):
+        """Construct CreateHyperParameterTuningJob request
+
+        Args:
+            job_name (str): Name of the tuning job being created.
+            tuning_config (dict): Configuration to launch the tuning job.
+            training_config (dict): Configuration to launch training jobs under the tuning job
+                using a single algorithm.
+            training_config_list (list[dict]): A list of configurations to launch training jobs
+                under the tuning job using one or multiple algorithms. Either training_config
+                or training_config_list should be provided, but not both.
+            warm_start_config (dict): Configuration defining the type of warm start and
+                other required configurations.
+            tags (list[dict]): List of tags for labeling the tuning job. For more, see
+                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+        Returns:
+            dict: A dictionary for CreateHyperParameterTuningJob request
+        """
         tune_request = {
             "HyperParameterTuningJobName": job_name,
             "HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config),
@@ -2053,9 +2092,7 @@ def create_tuning_job(
         if tags is not None:
             tune_request["Tags"] = tags
 
-        LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
-        LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
-        self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
+        return tune_request
 
     def describe_tuning_job(self, job_name):
         """Calls DescribeHyperParameterTuningJob API for the given job name, returns the response.

src/sagemaker/tuner.py (+21, -3)

@@ -346,7 +346,9 @@ def _prepare_static_hyperparameters_for_tuning(self, include_cls_metadata=False)
             estimator_name: self._prepare_static_hyperparameters(
                 estimator,
                 self._hyperparameter_ranges_dict[estimator_name],
-                include_cls_metadata.get(estimator_name, False),
+                include_cls_metadata.get(estimator_name, False)
+                if isinstance(include_cls_metadata, dict)
+                else include_cls_metadata,
             )
             for (estimator_name, estimator) in self.estimator_dict.items()
         }
@@ -1460,6 +1462,23 @@ def start_new(cls, tuner, inputs):
             sagemaker.tuner._TuningJob: Constructed object that captures all
                 information about the started job.
         """
+        tuner_args = cls._get_tuner_args(tuner, inputs)
+        tuner.sagemaker_session.create_tuning_job(**tuner_args)
+
+        return cls(tuner.sagemaker_session, tuner._current_job_name)
+
+    @classmethod
+    def _get_tuner_args(cls, tuner, inputs):
+        """Gets a dict of arguments for a new Amazon SageMaker tuning job from the tuner
+
+        Args:
+            tuner (:class:`~sagemaker.tuner.HyperparameterTuner`):
+                The ``HyperparameterTuner`` instance that started the job.
+            inputs: Information about the training data. Please refer to the
+                ``fit()`` method of the associated estimator.
+        Returns:
+            Dict: dict for `sagemaker.session.Session.tune` method
+        """
         warm_start_config_req = None
         if tuner.warm_start_config:
             warm_start_config_req = tuner.warm_start_config.to_input_req()
@@ -1506,8 +1525,7 @@ def start_new(cls, tuner, inputs):
                 for estimator_name in sorted(tuner.estimator_dict.keys())
             ]
 
-        tuner.sagemaker_session.create_tuning_job(**tuner_args)
-        return cls(tuner.sagemaker_session, tuner._current_job_name)
+        return tuner_args
 
     @staticmethod
     def _prepare_training_config(
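
Extracting `_get_tuner_args` out of `start_new` lets the workflow `TuningStep` (below) reuse the same argument-building logic without starting a job. A sketch of the equivalence, using the internal APIs this diff introduces (illustration only, assuming a configured `tuner` and `inputs`):

```python
from sagemaker.tuner import _TuningJob

# What _TuningJob.start_new(tuner, inputs) now does, in two separable steps:
tuner_args = _TuningJob._get_tuner_args(tuner, inputs)    # kwargs for create_tuning_job
tuner.sagemaker_session.create_tuning_job(**tuner_args)   # actually launches the job
```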

src/sagemaker/workflow/properties.py (+26, -16)

@@ -13,7 +13,7 @@
 """The properties definitions for workflow."""
 from __future__ import absolute_import
 
-from typing import Dict, Union
+from typing import Dict, Union, List
 
 import attr
 
@@ -40,27 +40,35 @@ def __new__(mcs, *args, **kwargs):
 class Properties(metaclass=PropertiesMeta):
     """Properties for use in workflow expressions."""
 
-    def __init__(self, path: str, shape_name: str = None):
+    def __init__(
+        self,
+        path: str,
+        shape_name: str = None,
+        shape_names: List[str] = None,
+    ):
         """Create a Properties instance representing the given shape.
 
         Args:
             path (str): The parent path of the Properties instance.
             shape_name (str): The botocore sagemaker service model shape name.
+            shape_names (str): A List of the botocore sagemaker service model shape name.
         """
         self._path = path
-        self._shape_name = shape_name
-
-        shape = Properties._shapes.get(self._shape_name, {})
-        shape_type = shape.get("type")
-        if shape_type in Properties._primitive_types:
-            self.__str__ = shape_name
-        elif shape_type == "structure":
-            members = shape["members"]
-            for key, info in members.items():
-                if Properties._shapes.get(info["shape"], {}).get("type") == "list":
-                    self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
-                else:
-                    self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
+        shape_names = [] if shape_names is None else shape_names
+        self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names
+
+        for name in self._shape_names:
+            shape = Properties._shapes.get(name, {})
+            shape_type = shape.get("type")
+            if shape_type in Properties._primitive_types:
+                self.__str__ = name
+            elif shape_type == "structure":
+                members = shape["members"]
+                for key, info in members.items():
+                    if Properties._shapes.get(info["shape"], {}).get("type") == "list":
+                        self.__dict__[key] = PropertiesList(f"{path}.{key}", info["shape"])
+                    else:
+                        self.__dict__[key] = Properties(f"{path}.{key}", info["shape"])
 
     @property
     def expr(self):
@@ -77,8 +85,10 @@ def __init__(self, path: str, shape_name: str = None):
         Args:
             path (str): The parent path of the PropertiesList instance.
             shape_name (str): The botocore sagemaker service model shape name.
+            root_shape_name (str): The botocore sagemaker service model shape name.
         """
         super(PropertiesList, self).__init__(path, shape_name)
+        self.shape_name = shape_name
         self._items: Dict[Union[int, str], Properties] = dict()
 
     def __getitem__(self, item: Union[int, str]):
@@ -88,7 +98,7 @@ def __getitem__(self, item: Union[int, str]):
             item (Union[int, str]): The index of the item in sequence.
         """
         if item not in self._items.keys():
-            shape = Properties._shapes.get(self._shape_name)
+            shape = Properties._shapes.get(self.shape_name)
             member = shape["member"]["shape"]
             if isinstance(item, str):
                 property_item = Properties(f"{self._path}['{item}']", member)
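
With the new `shape_names` parameter, a single `Properties` instance can merge the members of several botocore response shapes under one step path, which is exactly what the tuning step needs (it exposes both Describe* and List* responses). A minimal sketch; the step name is hypothetical:

```python
from sagemaker.workflow.properties import Properties

# Hypothetical step path; both shapes contribute members to the same instance.
props = Properties(
    path="Steps.MyTuningStep",
    shape_names=[
        "DescribeHyperParameterTuningJobResponse",
        "ListTrainingJobsForHyperParameterTuningJobResponse",
    ],
)

# TrainingJobSummaries comes from the List* shape; indexing yields a nested
# Properties whose expr is a pipeline "Get" expression:
best = props.TrainingJobSummaries[0].TrainingJobName
print(best.expr)  # {'Get': "Steps.MyTuningStep.TrainingJobSummaries[0].TrainingJobName"}
```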

src/sagemaker/workflow/steps.py (+133)

@@ -30,6 +30,7 @@
     Processor,
 )
 from sagemaker.transformer import Transformer, _TransformJob
+from sagemaker.tuner import HyperparameterTuner, _TuningJob
 from sagemaker.workflow.entities import (
     DefaultEnumMeta,
     Entity,
@@ -39,6 +40,7 @@
     PropertyFile,
     Properties,
 )
+from sagemaker.workflow.functions import Join
 
 
 class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
@@ -51,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
     TRAINING = "Training"
     TRANSFORM = "Transform"
     CALLBACK = "Callback"
+    TUNING = "Tuning"
 
 
 @attr.s
@@ -92,6 +95,7 @@ def add_depends_on(self, step_names: List[str]):
         """Add step names to the current step depends on list"""
         if not step_names:
             return
+
         if not self.depends_on:
             self.depends_on = []
         self.depends_on.extend(step_names)
@@ -429,3 +433,132 @@ def to_request(self) -> RequestType:
             property_file.expr for property_file in self.property_files
         ]
         return request_dict
+
+
+class TuningStep(Step):
+    """Tuning step for workflow."""
+
+    def __init__(
+        self,
+        name: str,
+        tuner: HyperparameterTuner,
+        inputs=None,
+        job_arguments: List[str] = None,
+        cache_config: CacheConfig = None,
+        depends_on: List[str] = None,
+    ):
+        """Construct a TuningStep, given a `HyperparameterTuner` instance.
+
+        In addition to the tuner instance, the other arguments are those that are supplied to
+        the `fit` method of the `sagemaker.tuner.HyperparameterTuner`.
+
+        Args:
+            name (str): The name of the tuning step.
+            tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
+            inputs: Information about the training data. Please refer to the
+                ``fit()`` method of the associated estimator, as this can take
+                any of the following forms:
+
+                * (str) - The S3 location where training data is saved.
+                * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
+                    If using multiple channels for training data, you can specify
+                    a dict mapping channel names to strings or
+                    :func:`~sagemaker.inputs.TrainingInput` objects.
+                * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
+                    that can provide additional information about the training dataset.
+                    See :func:`sagemaker.inputs.TrainingInput` for full details.
+                * (sagemaker.session.FileSystemInput) - channel configuration for
+                    a file system data source that can provide additional information as well as
+                    the path to the training dataset.
+                * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                    Amazon :class:~`Record` objects serialized and stored in S3.
+                    For use with an estimator for an Amazon algorithm.
+                * (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) -
+                    Amazon SageMaker channel configuration for a file system data source for
+                    Amazon algorithms.
+                * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                    :class:~`sagemaker.amazon.amazon_estimator.RecordSet` objects,
+                    where each instance is a different channel of training data.
+                * (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of
+                    :class:~`sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects,
+                    where each instance is a different channel of training data.
+            job_arguments (List[str]): A list of strings to be passed into the processing job.
+                Defaults to `None`.
+            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
+            depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TuningStep`
+                depends on.
+        """
+        super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
+        self.tuner = tuner
+        self.inputs = inputs
+        self.job_arguments = job_arguments
+        self._properties = Properties(
+            path=f"Steps.{name}",
+            shape_names=[
+                "DescribeHyperParameterTuningJobResponse",
+                "ListTrainingJobsForHyperParameterTuningJobResponse",
+            ],
+        )
+        self.cache_config = cache_config
+
+    @property
+    def arguments(self) -> RequestType:
+        """The arguments dict that is used to call `create_hyper_parameter_tuning_job`.
+
+        NOTE: The CreateHyperParameterTuningJob request is not quite the
+            args list that workflow needs.
+            The HyperParameterTuningJobName attribute cannot be included.
+        """
+        if self.tuner.estimator is not None:
+            self.tuner.estimator._prepare_for_training()
+        else:
+            for _, estimator in self.tuner.estimator_dict.items():
+                estimator._prepare_for_training()
+
+        self.tuner._prepare_for_tuning()
+        tuner_args = _TuningJob._get_tuner_args(self.tuner, self.inputs)
+        request_dict = self.tuner.sagemaker_session._get_tuning_request(**tuner_args)
+        request_dict.pop("HyperParameterTuningJobName")
+
+        return request_dict
+
+    @property
+    def properties(self):
+        """A Properties object representing the
+
+        `DescribeHyperParameterTuningJobResponse` and
+        `ListTrainingJobsForHyperParameterTuningJobResponse` data model.
+        """
+        return self._properties
+
+    def to_request(self) -> RequestType:
+        """Updates the dictionary with cache configuration."""
+        request_dict = super().to_request()
+        if self.cache_config:
+            request_dict.update(self.cache_config.config)
+
+        return request_dict
+
+    def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
+        """Get the model artifact s3 uri from the top performing training jobs.
+
+        Args:
+            top_k (int): The index of the top performing training job. The
+                tuning step stores up to 50 top performing training jobs, so
+                a valid top_k value is from 0 to 49. The best training job
+                model is at index 0.
+            s3_bucket (str): The s3 bucket where the training job output artifact is stored.
+            prefix (str): The s3 key prefix under which the training job output artifact is stored.
+        """
+        values = ["s3:/", s3_bucket]
+        if prefix != "" and prefix is not None:
+            values.append(prefix)
+
+        return Join(
+            on="/",
+            values=values
+            + [
+                self.properties.TrainingJobSummaries[top_k].TrainingJobName,
+                "output/model.tar.gz",
+            ],
+        )
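
End to end, the new step slots into a pipeline like any other. A minimal sketch, assuming a configured `HyperparameterTuner` named `tuner`; the step, bucket, and data-location names are hypothetical. `get_top_model_s3_uri` resolves at pipeline execution time via the `Join` expression built above:

```python
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import TuningStep

step_tune = TuningStep(
    name="MyTuningStep",
    tuner=tuner,  # assumed: a configured HyperparameterTuner
    inputs=TrainingInput(s3_data="s3://my-bucket/train"),  # hypothetical location
)

# S3 uri of the best model artifact (index 0 of up to 50 stored top jobs),
# e.g. to feed a downstream model-creation step:
best_model_uri = step_tune.get_top_model_s3_uri(
    top_k=0,
    s3_bucket="my-bucket",
    prefix="tuning-output",
)

pipeline = Pipeline(name="MyTuningPipeline", steps=[step_tune])
```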

tests/data/pytorch_mnist/mnist.py (+1, -1)

@@ -182,7 +182,7 @@ def train(args):
     accuracy = test(model, test_loader, device)
     save_model(model, args.model_dir)
 
-    logger.debug("Overall test accuracy: {}".format(accuracy))
+    logger.debug("Overall test accuracy: {};".format(accuracy))
 
 
 def test(model, test_loader, device):
