Skip to content

Commit ace07d7

Browse files
authored
fix: estimator hyperparameters in script mode (aws#3344)
1 parent e611816 commit ace07d7

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,6 +2629,8 @@ def __init__(
26292629
**kwargs,
26302630
)
26312631

2632+
self.set_hyperparameters(**self._hyperparameters)
2633+
26322634
def training_image_uri(self):
26332635
"""Returns the docker image to use for training.
26342636
@@ -2644,9 +2646,15 @@ def set_hyperparameters(self, **kwargs):
26442646
training code on SageMaker. For convenience, this accepts other types
26452647
for keys and values, but ``str()`` will be called to convert them before
26462648
training.
2649+
2650+
If a source directory is specified, this method escapes the dict argument as JSON,
2651+
and updates the private hyperparameter attribute.
26472652
"""
2648-
for k, v in kwargs.items():
2649-
self._hyperparameters[k] = v
2653+
if self.source_dir:
2654+
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs))
2655+
else:
2656+
for k, v in kwargs.items():
2657+
self._hyperparameters[k] = v
26502658

26512659
def hyperparameters(self):
26522660
"""Returns the hyperparameters as a dictionary to use for training.

tests/unit/test_estimator.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from copy import deepcopy
1415

1516
import logging
1617
import json
@@ -3825,6 +3826,12 @@ def test_script_mode_estimator_same_calls_as_framework(
38253826

38263827
model_uri = "s3://someprefix2/models/model.tar.gz"
38273828
training_data_uri = "s3://bucket/mydata"
3829+
hyperparameters = {
3830+
"int_hyperparam": 1,
3831+
"string_hyperparam": "hello",
3832+
"stringified_numeric_hyperparam": "44",
3833+
"float_hyperparam": 1.234,
3834+
}
38283835

38293836
generic_estimator = Estimator(
38303837
entry_point=SCRIPT_PATH,
@@ -3838,6 +3845,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38383845
model_uri=model_uri,
38393846
dependencies=[],
38403847
debugger_hook_config={},
3848+
hyperparameters=deepcopy(hyperparameters),
38413849
)
38423850
generic_estimator.fit(training_data_uri)
38433851

@@ -3858,6 +3866,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38583866
model_uri=model_uri,
38593867
dependencies=[],
38603868
debugger_hook_config={},
3869+
hyperparameters=deepcopy(hyperparameters),
38613870
)
38623871
framework_estimator.fit(training_data_uri)
38633872

@@ -4394,3 +4403,51 @@ def test_insert_invalid_source_code_args():
43944403
assert (
43954404
"The entry_point should not be a pipeline variable " "when source_dir is a local path"
43964405
) in str(err.value)
4406+
4407+
4408+
@patch("time.time", return_value=TIME)
4409+
@patch("sagemaker.estimator.tar_and_upload_dir")
4410+
@patch("sagemaker.model.Model._upload_code")
4411+
def test_script_mode_estimator_escapes_hyperparameters_as_json(
4412+
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
4413+
):
4414+
patched_tar_and_upload_dir.return_value = UploadedCode(
4415+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
4416+
)
4417+
sagemaker_session.boto_region_name = REGION
4418+
4419+
instance_type = "ml.p2.xlarge"
4420+
instance_count = 1
4421+
4422+
training_data_uri = "s3://bucket/mydata"
4423+
4424+
jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
4425+
4426+
hyperparameters = {
4427+
"int_hyperparam": 1,
4428+
"string_hyperparam": "hello",
4429+
"stringified_numeric_hyperparam": "44",
4430+
"float_hyperparam": 1.234,
4431+
}
4432+
4433+
generic_estimator = Estimator(
4434+
entry_point=SCRIPT_PATH,
4435+
role=ROLE,
4436+
region=REGION,
4437+
sagemaker_session=sagemaker_session,
4438+
instance_count=instance_count,
4439+
instance_type=instance_type,
4440+
source_dir=jumpstart_source_dir,
4441+
image_uri=IMAGE_URI,
4442+
model_uri=MODEL_DATA,
4443+
hyperparameters=hyperparameters,
4444+
)
4445+
generic_estimator.fit(training_data_uri)
4446+
4447+
formatted_hyperparams = EstimatorBase._json_encode_hyperparameters(hyperparameters)
4448+
4449+
assert (
4450+
set(formatted_hyperparams.items())
4451+
- set(sagemaker_session.train.call_args_list[0][1]["hyperparameters"].items())
4452+
== set()
4453+
)

0 commit comments

Comments
 (0)