Skip to content

Commit f27d5a7

Browse files
committed
feat: script mode support for estimator class
1 parent b09793a commit f27d5a7

File tree

8 files changed: +518 -117 lines changed

src/sagemaker/algorithm.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ def __init__(
174174
self.validate_train_spec()
175175
self.hyperparameter_definitions = self._parse_hyperparameters()
176176

177-
self.hyperparam_dict = {}
177+
self._hyperparameters = {}
178178
if hyperparameters:
179179
self.set_hyperparameters(**hyperparameters)
180180

@@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
215215
"""Placeholder docstring"""
216216
for k, v in kwargs.items():
217217
value = self._validate_and_cast_hyperparameter(k, v)
218-
self.hyperparam_dict[k] = value
218+
self._hyperparameters[k] = value
219219

220220
self._validate_and_set_default_hyperparameters()
221221

@@ -225,7 +225,7 @@ def hyperparameters(self):
225225
The fit() method, that does the model training, calls this method to
226226
find the hyperparameters you specified.
227227
"""
228-
return self.hyperparam_dict
228+
return self._hyperparameters
229229

230230
def training_image_uri(self):
231231
"""Returns the docker image to use for training.
@@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
464464
# Check if all the required hyperparameters are set. If there is a default value
465465
# for one, set it.
466466
for name, definition in self.hyperparameter_definitions.items():
467-
if name not in self.hyperparam_dict:
467+
if name not in self._hyperparameters:
468468
spec = definition["spec"]
469469
if "DefaultValue" in spec:
470-
self.hyperparam_dict[name] = spec["DefaultValue"]
470+
self._hyperparameters[name] = spec["DefaultValue"]
471471
elif "IsRequired" in spec and spec["IsRequired"]:
472472
raise ValueError("Required hyperparameter: %s is not set" % name)
473473

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717

18-
from sagemaker.estimator import Framework
18+
from sagemaker.estimator import Framework, EstimatorBase
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
@@ -158,7 +158,9 @@ def hyperparameters(self):
158158

159159
# remove unset keys.
160160
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
161-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
161+
hyperparameters.update(
162+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
163+
)
162164
return hyperparameters
163165

164166
def create_model(

src/sagemaker/estimator.py

Lines changed: 408 additions & 101 deletions
Large diffs are not rendered by default.

src/sagemaker/huggingface/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import re
1818

1919
from sagemaker.deprecations import renamed_kwargs
20-
from sagemaker.estimator import Framework
20+
from sagemaker.estimator import Framework, EstimatorBase
2121
from sagemaker.fw_utils import (
2222
framework_name_from_image,
2323
warn_if_parameter_server_with_multi_gpu,
@@ -246,13 +246,13 @@ def hyperparameters(self):
246246
distribution=self.distribution
247247
)
248248
hyperparameters.update(
249-
Framework._json_encode_hyperparameters(distributed_training_hyperparameters)
249+
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
250250
)
251251

252252
if self.compiler_config:
253253
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
254254
hyperparameters.update(
255-
Framework._json_encode_hyperparameters(training_compiler_hyperparameters)
255+
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
256256
)
257257

258258
return hyperparameters

src/sagemaker/pytorch/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
from packaging.version import Version
1919

2020
from sagemaker.deprecations import renamed_kwargs
21-
from sagemaker.estimator import Framework
21+
from sagemaker.estimator import Framework, EstimatorBase
2222
from sagemaker.fw_utils import (
2323
framework_name_from_image,
2424
framework_version_from_tag,
@@ -192,7 +192,9 @@ def hyperparameters(self):
192192
additional_hyperparameters = self._distribution_configuration(
193193
distribution=self.distribution
194194
)
195-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
195+
hyperparameters.update(
196+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
197+
)
196198
return hyperparameters
197199

198200
def create_model(

src/sagemaker/rl/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
import re
1919

2020
from sagemaker import image_uris, fw_utils
21-
from sagemaker.estimator import Framework
21+
from sagemaker.estimator import Framework, EstimatorBase
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
2424
from sagemaker.tensorflow.model import TensorFlowModel
@@ -340,7 +340,9 @@ def hyperparameters(self):
340340
SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE,
341341
}
342342

343-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
343+
hyperparameters.update(
344+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
345+
)
344346
return hyperparameters
345347

346348
@classmethod

src/sagemaker/tensorflow/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker import image_uris, s3, utils
2121
from sagemaker.deprecations import renamed_kwargs
22-
from sagemaker.estimator import Framework
22+
from sagemaker.estimator import Framework, EstimatorBase
2323
import sagemaker.fw_utils as fw
2424
from sagemaker.tensorflow import defaults
2525
from sagemaker.tensorflow.model import TensorFlowModel
@@ -327,7 +327,9 @@ def hyperparameters(self):
327327
)
328328
additional_hyperparameters["model_dir"] = self.model_dir
329329

330-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
330+
hyperparameters.update(
331+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
332+
)
331333
return hyperparameters
332334

333335
def _default_s3_path(self, directory, mpi=False):

tests/unit/test_estimator.py

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
import os
1818
import subprocess
1919
from time import sleep
20+
from sagemaker.fw_utils import UploadedCode
21+
2022

2123
import pytest
2224
from botocore.exceptions import ClientError
@@ -3350,3 +3352,87 @@ def test_image_name_map(sagemaker_session):
33503352
)
33513353

33523354
assert e.image_uri == IMAGE_URI
3355+
3356+
3357+
@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3")
3358+
def test_script_mode_estimator(patched_stage_user_code, sagemaker_session):
3359+
patched_stage_user_code.return_value = UploadedCode(
3360+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3361+
)
3362+
script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"
3363+
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
3364+
model_uri = "s3://someprefix2/models/model.tar.gz"
3365+
t = Estimator(
3366+
entry_point=SCRIPT_PATH,
3367+
role=ROLE,
3368+
sagemaker_session=sagemaker_session,
3369+
instance_count=INSTANCE_COUNT,
3370+
instance_type=INSTANCE_TYPE,
3371+
source_dir=script_uri,
3372+
image_uri=image_uri,
3373+
model_uri=model_uri,
3374+
)
3375+
t.fit("s3://bucket/mydata")
3376+
3377+
patched_stage_user_code.assert_called_once()
3378+
sagemaker_session.train.assert_called_once()
3379+
3380+
3381+
@patch("time.time", return_value=TIME)
3382+
@patch("sagemaker.estimator.tar_and_upload_dir")
3383+
def test_script_mode_estimator_same_calls_as_framework(
3384+
patched_tar_and_upload_dir, sagemaker_session
3385+
):
3386+
3387+
patched_tar_and_upload_dir.return_value = UploadedCode(
3388+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3389+
)
3390+
sagemaker_session.boto_region_name = REGION
3391+
3392+
script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"
3393+
3394+
instance_type = "ml.p2.xlarge"
3395+
instance_count = 1
3396+
3397+
model_uri = "s3://someprefix2/models/model.tar.gz"
3398+
training_data_uri = "s3://bucket/mydata"
3399+
3400+
generic_estimator = Estimator(
3401+
entry_point=SCRIPT_PATH,
3402+
role=ROLE,
3403+
region=REGION,
3404+
sagemaker_session=sagemaker_session,
3405+
instance_count=instance_count,
3406+
instance_type=instance_type,
3407+
source_dir=script_uri,
3408+
image_uri=IMAGE_URI,
3409+
model_uri=model_uri,
3410+
environment={"USE_SMDEBUG": "0"},
3411+
dependencies=[],
3412+
debugger_hook_config={},
3413+
)
3414+
generic_estimator.fit(training_data_uri)
3415+
3416+
generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list
3417+
generic_estimator_train_args = sagemaker_session.train.call_args_list
3418+
3419+
patched_tar_and_upload_dir.reset_mock()
3420+
sagemaker_session.train.reset_mock()
3421+
3422+
framework_estimator = DummyFramework(
3423+
entry_point=SCRIPT_PATH,
3424+
role=ROLE,
3425+
region=REGION,
3426+
source_dir=script_uri,
3427+
instance_count=instance_count,
3428+
instance_type=instance_type,
3429+
sagemaker_session=sagemaker_session,
3430+
model_uri=model_uri,
3431+
dependencies=[],
3432+
debugger_hook_config={},
3433+
)
3434+
framework_estimator.fit(training_data_uri)
3435+
3436+
assert len(generic_estimator_tar_and_upload_dir_args) == 1
3437+
assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list
3438+
assert generic_estimator_train_args == sagemaker_session.train.call_args_list

0 commit comments

Comments (0)