Commit fecc3c8
Merge branch 'master' into eia_151
2 parents: 3ac5bf8 + e521b87

11 files changed: +175 -16 lines

CHANGELOG.md

+13
@@ -1,5 +1,18 @@
 # Changelog
 
+## v2.53.0 (2021-08-12)
+
+### Features
+
+ * support tuning step parameter range parameterization + support retry strategy in tuner
+
+## v2.52.2.post0 (2021-08-11)
+
+### Documentation Changes
+
+ * clarify that default_bucket creates a bucket
+ * Minor updates to Clarify API documentation
+
 ## v2.52.2 (2021-08-10)
 
 ### Bug Fixes and Other Changes

VERSION

+1 -1
@@ -1 +1 @@
-2.52.3.dev0
+2.53.1.dev0

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

-1
@@ -5,7 +5,6 @@ ConditionStep
 -------------
 
 .. autoclass:: sagemaker.workflow.condition_step.ConditionStep
-
 .. deprecated:: sagemaker.workflow.condition_step.JsonGet
 
 Conditions

src/sagemaker/local/local_session.py

+15
@@ -571,6 +571,21 @@ def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
         # on local mode.
         pass  # pylint: disable=unnecessary-pass
 
+    def logs_for_processing_job(self, job_name, wait=False, poll=10):
+        """A no-op method meant to override the sagemaker client.
+
+        Args:
+            job_name:
+            wait: (Default value = False)
+            poll: (Default value = 10)
+
+        Returns:
+
+        """
+        # override logs_for_job() as it doesn't need to perform any action
+        # on local mode.
+        pass  # pylint: disable=unnecessary-pass
+
 
 class file_input(object):
     """Amazon SageMaker channel configuration for FILE data sources, used in local mode."""

src/sagemaker/parameter.py

+10 -4
@@ -12,7 +12,9 @@
 # language governing permissions and limitations under the License.
 """Placeholder docstring"""
 from __future__ import absolute_import
+
 import json
+from sagemaker.workflow.parameters import Parameter as PipelineParameter
 
 
 class ParameterRange(object):
@@ -68,8 +70,12 @@ def as_tuning_range(self, name):
         """
         return {
             "Name": name,
-            "MinValue": str(self.min_value),
-            "MaxValue": str(self.max_value),
+            "MinValue": str(self.min_value)
+            if not isinstance(self.min_value, PipelineParameter)
+            else self.min_value,
+            "MaxValue": str(self.max_value)
+            if not isinstance(self.max_value, PipelineParameter)
+            else self.max_value,
             "ScalingType": self.scaling_type,
         }
 
@@ -103,9 +109,9 @@ def __init__(self, values):  # pylint: disable=super-init-not-called
             This input will be converted into a list of strings.
         """
         if isinstance(values, list):
-            self.values = [str(v) for v in values]
+            self.values = [str(v) if not isinstance(v, PipelineParameter) else v for v in values]
         else:
-            self.values = [str(values) if not isinstance(values, PipelineParameter) else values]
 
     def as_tuning_range(self, name):
         """Represent the parameter range as a dictionary.

src/sagemaker/session.py

+6
@@ -357,6 +357,8 @@ def list_s3_files(self, bucket, key_prefix):
     def default_bucket(self):
         """Return the name of the default bucket to use in relevant Amazon SageMaker interactions.
 
+        This function will create the s3 bucket if it does not exist.
+
         Returns:
             str: The name of the default bucket, which is of the form:
             ``sagemaker-{region}-{AWS account ID}``.
@@ -2211,6 +2213,7 @@ def _map_training_config(
         use_spot_instances=False,
         checkpoint_s3_uri=None,
         checkpoint_local_path=None,
+        max_retry_attempts=None,
     ):
         """Construct a dictionary of training job configuration from the arguments.
 
@@ -2264,6 +2267,7 @@
             objective_metric_name (str): Name of the metric for evaluating training jobs.
             parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
                 be one of three types: Continuous, Integer, or Categorical.
+            max_retry_attempts (int): The number of times to retry the job.
 
         Returns:
             A dictionary of training job configuration. For format details, please refer to
@@ -2320,6 +2324,8 @@
         if parameter_ranges is not None:
            training_job_definition["HyperParameterRanges"] = parameter_ranges
 
+        if max_retry_attempts is not None:
+            training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
         return training_job_definition
 
     def stop_tuning_job(self, name):
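For reference, a sketch of the block this adds to the training job definition (key names as in the diff; the surrounding configuration is elided):

# Sketch: with max_retry_attempts set, _map_training_config() appends a
# RetryStrategy block to the training job definition it builds.
max_retry_attempts = 3
training_job_definition = {}  # other keys elided
if max_retry_attempts is not None:
    training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
assert training_job_definition["RetryStrategy"] == {"MaximumRetryAttempts": 3}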

src/sagemaker/tuner.py

+7 -1
@@ -1507,7 +1507,10 @@ def _get_tuner_args(cls, tuner, inputs):
 
         if tuner.estimator is not None:
             tuner_args["training_config"] = cls._prepare_training_config(
-                inputs, tuner.estimator, tuner.static_hyperparameters, tuner.metric_definitions
+                inputs=inputs,
+                estimator=tuner.estimator,
+                static_hyperparameters=tuner.static_hyperparameters,
+                metric_definitions=tuner.metric_definitions,
             )
 
         if tuner.estimator_dict is not None:
@@ -1580,6 +1583,9 @@ def _prepare_training_config(
         if parameter_ranges is not None:
             training_config["parameter_ranges"] = parameter_ranges
 
+        if estimator.max_retry_attempts is not None:
+            training_config["max_retry_attempts"] = estimator.max_retry_attempts
+
         return training_config
 
     def stop(self):
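Putting the tuner and session changes together, a hedged end-to-end sketch (the image URI and role are placeholders; the metric name and regex are taken from the tests in this commit):

# Sketch: the estimator's max_retry_attempts flows through the tuner into
# the tuning job's per-definition RetryStrategy.
from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

estimator = Estimator(
    image_uri="<training-image-uri>",  # placeholder
    role="<execution-role-arn>",  # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    max_retry_attempts=3,  # the job is retried up to 3 times on failure
)
tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="test:acc",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.0001, 0.05)},
    metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
)
# tuner.fit(...) would then launch a tuning job whose training definition
# includes {"RetryStrategy": {"MaximumRetryAttempts": 3}}.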

src/sagemaker/workflow/pipeline.py

+1
@@ -320,6 +320,7 @@ def _interpolate(
     """
     if isinstance(obj, (Expression, Parameter, Properties)):
         return obj.expr
+
     if isinstance(obj, CallbackOutput):
         step_name = callback_output_to_step_map[obj.output_name]
         return obj.expr(step_name)

tests/integ/test_workflow.py

+95 -5
@@ -1075,7 +1075,7 @@ def test_conditional_pytorch_training_model_registration(
     pass
 
 
-def test_tuning(
+def test_tuning_single_algo(
     sagemaker_session,
     role,
     cpu_instance_type,
@@ -1098,14 +1098,17 @@
         role=role,
         framework_version="1.5.0",
         py_version="py3",
-        instance_count=1,
-        instance_type="ml.m5.xlarge",
+        instance_count=instance_count,
+        instance_type=instance_type,
         sagemaker_session=sagemaker_session,
         enable_sagemaker_metrics=True,
+        max_retry_attempts=3,
     )
 
+    min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
+    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
     hyperparameter_ranges = {
-        "batch-size": IntegerParameter(64, 128),
+        "batch-size": IntegerParameter(min_batch_size, max_batch_size),
    }
 
     tuner = HyperparameterTuner(
@@ -1161,7 +1164,7 @@
 
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type],
+        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
         steps=[step_tune, step_best_model, step_second_best_model],
         sagemaker_session=sagemaker_session,
     )
@@ -1185,6 +1188,93 @@
         pass
 
 
+def test_tuning_multi_algos(
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+):
+    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
+    entry_point = os.path.join(base_dir, "mnist.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "training"),
+        key_prefix="integ-test-data/pytorch_mnist/training",
+    )
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+
+    pytorch_estimator = PyTorch(
+        entry_point=entry_point,
+        role=role,
+        framework_version="1.5.0",
+        py_version="py3",
+        instance_count=instance_count,
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        enable_sagemaker_metrics=True,
+        max_retry_attempts=3,
+    )
+
+    min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
+    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
+
+    tuner = HyperparameterTuner.create(
+        estimator_dict={
+            "estimator-1": pytorch_estimator,
+            "estimator-2": pytorch_estimator,
+        },
+        objective_metric_name_dict={
+            "estimator-1": "test:acc",
+            "estimator-2": "test:acc",
+        },
+        hyperparameter_ranges_dict={
+            "estimator-1": {"batch-size": IntegerParameter(min_batch_size, max_batch_size)},
+            "estimator-2": {"batch-size": IntegerParameter(min_batch_size, max_batch_size)},
+        },
+        metric_definitions_dict={
+            "estimator-1": [{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
+            "estimator-2": [{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
+        },
+    )
+    inputs = {
+        "estimator-1": TrainingInput(s3_data=input_path),
+        "estimator-2": TrainingInput(s3_data=input_path),
+    }
+
+    step_tune = TuningStep(
+        name="my-tuning-step",
+        tuner=tuner,
+        inputs=inputs,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
+        steps=[step_tune],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+        )
+
+        execution = pipeline.start(parameters={})
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_mxnet_model_registration(
     sagemaker_session,
     role,

tests/unit/sagemaker/workflow/test_steps.py

+13 -4
@@ -716,14 +716,16 @@ def test_multi_algo_tuning_step(sagemaker_session):
     data_source_uri_parameter = ParameterString(
         name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
     )
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
     estimator = Estimator(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
+        instance_count=instance_count,
         instance_type="ml.c5.4xlarge",
         profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
         rules=[],
         sagemaker_session=sagemaker_session,
+        max_retry_attempts=10,
     )
 
     estimator.set_hyperparameters(
@@ -739,8 +741,9 @@
         augmentation_type="crop",
     )
 
+    initial_lr_param = ParameterString(name="InitialLR", default_value="0.0001")
     hyperparameter_ranges = {
-        "learning_rate": ContinuousParameter(0.0001, 0.05),
+        "learning_rate": ContinuousParameter(initial_lr_param, 0.05),
         "momentum": ContinuousParameter(0.0, 0.99),
         "weight_decay": ContinuousParameter(0.0, 0.99),
     }
@@ -825,7 +828,7 @@
                 "ContinuousParameterRanges": [
                     {
                         "Name": "learning_rate",
-                        "MinValue": "0.0001",
+                        "MinValue": initial_lr_param,
                         "MaxValue": "0.05",
                         "ScalingType": "Auto",
                     },
@@ -845,6 +848,9 @@
                     "CategoricalParameterRanges": [],
                     "IntegerParameterRanges": [],
                 },
+                "RetryStrategy": {
+                    "MaximumRetryAttempts": 10,
+                },
             },
             {
                 "StaticHyperParameters": {
@@ -889,7 +895,7 @@
                 "ContinuousParameterRanges": [
                     {
                         "Name": "learning_rate",
-                        "MinValue": "0.0001",
+                        "MinValue": initial_lr_param,
                         "MaxValue": "0.05",
                         "ScalingType": "Auto",
                     },
@@ -909,6 +915,9 @@
                     "CategoricalParameterRanges": [],
                     "IntegerParameterRanges": [],
                 },
+                "RetryStrategy": {
+                    "MaximumRetryAttempts": 10,
+                },
             },
         ],
     },

tests/unit/test_local_session.py

+14
@@ -551,6 +551,20 @@ def test_describe_transform_job_does_not_exist(LocalSession, _LocalTransformJob):
     local_sagemaker_client.describe_transform_job("transform-job-does-not-exist")
 
 
+@patch("sagemaker.local.image._SageMakerContainer.process")
+@patch("sagemaker.local.local_session.LocalSession")
+def test_logs_for_job(process, LocalSession):
+    local_job_logs = LocalSession.logs_for_job("my-processing-job")
+    assert local_job_logs is not None
+
+
+@patch("sagemaker.local.image._SageMakerContainer.process")
+@patch("sagemaker.local.local_session.LocalSession")
+def test_logs_for_processing_job(process, LocalSession):
+    local_processing_job_logs = LocalSession.logs_for_processing_job("my-processing-job")
+    assert local_processing_job_logs is not None
+
+
 @patch("sagemaker.local.local_session.LocalSession")
 def test_describe_endpoint_config(LocalSession):
     local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
