Commit 6cd0cb3

Support JsonGet/Join parameterization in the tuning step's hyperparameter ranges and static hyperparameters.

1 parent fd7a335 · commit 6cd0cb3

File tree: 4 files changed (+106 −16 lines)

  src/sagemaker/parameter.py    (+9 −6)
  src/sagemaker/tuner.py        (+22 −1)
  tests/integ/test_workflow.py  (+46 −6)
  tests/unit/test_tuner.py      (+29 −3)
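In practice, this change means the tuner no longer stringifies SageMaker Pipelines constructs: JsonGet and Join expressions, which only resolve at pipeline execution time, now pass through hyperparameter ranges and static hyperparameters intact. A minimal sketch of what the commit enables; the names MinEpochs, my-process, DataAttributes, and max_epochs are invented for illustration:

    from sagemaker.tuner import IntegerParameter
    from sagemaker.workflow.functions import JsonGet
    from sagemaker.workflow.parameters import ParameterInteger

    # Both bounds may now be deferred values that SageMaker Pipelines
    # resolves at execution time rather than plain numbers.
    epochs_range = IntegerParameter(
        min_value=ParameterInteger(name="MinEpochs", default_value=1),
        max_value=JsonGet(
            step_name="my-process",
            property_file="DataAttributes",
            json_path="max_epochs",
        ),
    )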

src/sagemaker/parameter.py (+9 −6)
@@ -15,6 +15,8 @@
 
 import json
 from sagemaker.workflow.parameters import Parameter as PipelineParameter
+from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
+from sagemaker.workflow.functions import Join as PipelineJoin
 
 
 class ParameterRange(object):
@@ -71,10 +73,10 @@ def as_tuning_range(self, name):
         return {
             "Name": name,
             "MinValue": str(self.min_value)
-            if not isinstance(self.min_value, PipelineParameter)
+            if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
             else self.min_value,
             "MaxValue": str(self.max_value)
-            if not isinstance(self.max_value, PipelineParameter)
+            if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
             else self.max_value,
             "ScalingType": self.scaling_type,
         }
@@ -108,10 +110,11 @@ def __init__(self, values):  # pylint: disable=super-init-not-called
             values (list or object): The possible values for the hyperparameter.
                 This input will be converted into a list of strings.
         """
-        if isinstance(values, list):
-            self.values = [str(v) if not isinstance(v, PipelineParameter) else v for v in values]
-        else:
-            self.values = [str(values) if not isinstance(values, PipelineParameter) else values]
+        values = values if isinstance(values, list) else [values]
+        self.values = [
+            str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v
+            for v in values
+        ]
 
     def as_tuning_range(self, name):
         """Represent the parameter range as a dictionary.
src/sagemaker/tuner.py (+22 −1)
@@ -38,6 +38,10 @@
     IntegerParameter,
     ParameterRange,
 )
+from sagemaker.workflow.parameters import Parameter as PipelineParameter
+from sagemaker.workflow.functions import JsonGet as PipelineJsonGet
+from sagemaker.workflow.functions import Join as PipelineJoin
+
 from sagemaker.session import Session
 from sagemaker.utils import base_from_name, base_name_from_image, name_from_base
 
@@ -59,6 +63,18 @@
 logger = logging.getLogger(__name__)
 
 
+def is_pipeline_parameters(value):
+    """Determine if a value is a pipeline parameter or function representation
+
+    Args:
+        value (float or int): The value to be verified.
+
+    Returns:
+        bool: True if it is, False otherwise.
+    """
+    return isinstance(value, (PipelineParameter, PipelineJsonGet, PipelineJoin))
+
+
 class WarmStartTypes(Enum):
     """Warm Start Configuration type.
@@ -359,7 +375,12 @@ def _prepare_static_hyperparameters(
     ):
         """Prepare static hyperparameters for one estimator before tuning."""
         # Remove any hyperparameter that will be tuned
-        static_hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
+        static_hyperparameters = {
+            str(k): str(v)
+            if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin))
+            else v
+            for (k, v) in estimator.hyperparameters().items()
+        }
         for hyperparameter_name in hyperparameter_ranges.keys():
             static_hyperparameters.pop(hyperparameter_name, None)

tests/integ/test_workflow.py (+46 −6)
@@ -66,14 +66,13 @@
     ConditionIn,
     ConditionLessThanOrEqualTo,
 )
-from sagemaker.workflow.condition_step import ConditionStep, JsonGet
+from sagemaker.workflow.condition_step import ConditionStep
 from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
 from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
-from sagemaker.workflow.properties import PropertyFile
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.execution_variables import ExecutionVariables
-from sagemaker.workflow.functions import Join
+from sagemaker.workflow.functions import Join, JsonGet
 from sagemaker.wrangler.ingestion import generate_data_ingestion_flow_from_s3_input
 from sagemaker.workflow.parameters import (
     ParameterInteger,
@@ -87,6 +86,7 @@
     TuningStep,
     TransformStep,
     TransformInput,
+    PropertyFile,
 )
 from sagemaker.workflow.step_collections import RegisterModel
 from sagemaker.workflow.pipeline import Pipeline
@@ -137,7 +137,7 @@ def feature_store_session(sagemaker_session):
 
 @pytest.fixture
 def pipeline_name():
-    return f"my-pipeline-{int(time.time() * 10**7)}"
+    return f"my-pipeline-{int(time.time() * 10 ** 7)}"
 
 
 @pytest.fixture
@@ -1371,6 +1371,8 @@ def test_tuning_multi_algos(
     cpu_instance_type,
     pipeline_name,
     region_name,
+    script_dir,
+    athena_dataset_definition,
 ):
     base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
     entry_point = os.path.join(base_dir, "mnist.py")
@@ -1382,6 +1384,42 @@ def test_tuning_multi_algos(
     instance_count = ParameterInteger(name="InstanceCount", default_value=1)
     instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
 
+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
+
+    sklearn_processor = SKLearnProcessor(
+        framework_version="0.20.0",
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+
+    property_file = PropertyFile(
+        name="DataAttributes", output_name="attributes", path="attributes.json"
+    )
+
+    step_process = ProcessingStep(
+        name="my-process",
+        display_name="ProcessingStep",
+        description="description for Processing step",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(output_name="attributes", source="/opt/ml/processing/attributes.json"),
+        ],
+        property_files=[property_file],
+        code=os.path.join(script_dir, "preprocessing.py"),
+    )
+
+    static_hp_1 = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    json_get_hp = JsonGet(
+        step_name=step_process.name, property_file=property_file, json_path="train_size"
+    )
     pytorch_estimator = PyTorch(
         entry_point=entry_point,
         role=role,
@@ -1392,10 +1430,11 @@ def test_tuning_multi_algos(
         sagemaker_session=sagemaker_session,
         enable_sagemaker_metrics=True,
         max_retry_attempts=3,
+        hyperparameters={"static-hp": static_hp_1, "train_size": json_get_hp},
     )
 
     min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
-    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
+    max_batch_size = json_get_hp
 
     tuner = HyperparameterTuner.create(
         estimator_dict={
@@ -1415,6 +1454,7 @@ def test_tuning_multi_algos(
             "estimator-2": [{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
         },
     )
+
     inputs = {
         "estimator-1": TrainingInput(s3_data=input_path),
         "estimator-2": TrainingInput(s3_data=input_path),
@@ -1429,7 +1469,7 @@
     pipeline = Pipeline(
         name=pipeline_name,
         parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
-        steps=[step_tune],
+        steps=[step_process, step_tune],
         sagemaker_session=sagemaker_session,
     )
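The step_tune definition itself sits in unchanged lines outside this diff. For context, wiring the tuner into the pipeline presumably resembles the following sketch; this is a hypothetical reconstruction, not part of the commit:

    from sagemaker.workflow.steps import TuningStep

    # Hypothetical reconstruction of the unchanged tuning step the test reuses.
    step_tune = TuningStep(
        name="my-tuning-step",
        tuner=tuner,
        inputs=inputs,
    )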

tests/unit/test_tuner.py (+29 −3)
@@ -30,6 +30,8 @@
     create_transfer_learning_tuner,
     HyperparameterTuner,
 )
+from sagemaker.workflow.functions import JsonGet, Join
+from sagemaker.workflow.parameters import ParameterString, ParameterInteger
 
 from .tuner_test_utils import *  # noqa: F403
 
@@ -68,14 +70,24 @@ def tuner(estimator):
 
 
 def test_prepare_for_training(tuner):
-    static_hyperparameters = {"validated": 1, "another_one": 0}
+    hp1 = JsonGet(step_name="stepname", property_file="pf", json_path="jp")
+    hp2 = Join(on="/", values=["1", "2", ParameterString(name="ps", default_value="3")])
+
+    static_hyperparameters = {
+        "validated": 1,
+        "another_one": 0,
+        "hp1": hp1,
+        "hp2": hp2,
+    }
+
     tuner.estimator.set_hyperparameters(**static_hyperparameters)
     tuner._prepare_for_tuning()
 
     assert tuner._current_job_name.startswith(IMAGE_NAME)
-
-    assert len(tuner.static_hyperparameters) == 1
+    assert len(tuner.static_hyperparameters) == 3
     assert tuner.static_hyperparameters["another_one"] == "0"
+    assert tuner.static_hyperparameters["hp1"] == hp1
+    assert tuner.static_hyperparameters["hp2"] == hp2
 
 
 def test_prepare_for_tuning_with_amazon_estimator(tuner, sagemaker_session):
@@ -1156,6 +1168,20 @@ def test_integer_parameter_ranges():
     assert ranges["ScalingType"] == "Auto"
 
 
+def test_integer_parameter_ranges_with_pipeline_parameter():
+    min = ParameterInteger(name="p", default_value=2)
+    max = JsonGet(step_name="sn", property_file="pf", json_path="jp")
+    scale = ParameterString(name="scale", default_value="Auto")
+    int_param = IntegerParameter(min, max)
+    ranges = int_param.as_tuning_range("some")
+
+    assert len(ranges.keys()) == 4
+    assert ranges["Name"] == "some"
+    assert ranges["MinValue"] == min
+    assert ranges["MaxValue"] == max
+    assert ranges["ScalingType"] == scale
+
+
 def test_integer_parameter_scaling_type():
     int_param = IntegerParameter(2, 3, scaling_type="Linear")
     int_range = int_param.as_tuning_range("range")
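Note that the new test's last assertion compares the plain string "Auto" against a ParameterString. That only passes if the pipeline parameter classes also behave as their Python primitives, which the SDK's ParameterString/ParameterInteger are assumed to do here (they subclass str/int respectively):

    from sagemaker.workflow.parameters import ParameterInteger, ParameterString

    # Assumed SDK behavior: pipeline parameters double as their primitive
    # types, so equality with plain literals compares the default values.
    assert "Auto" == ParameterString(name="scale", default_value="Auto")
    assert 2 == ParameterInteger(name="p", default_value=2)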
