Skip to content

Commit d9d8c68

Browse files
authored
feat: Script mode support for Estimator class (aws#2834)
1 parent 167b723 commit d9d8c68

File tree

9 files changed

+548
-124
lines changed

9 files changed

+548
-124
lines changed

src/sagemaker/algorithm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def __init__(
174174
self.validate_train_spec()
175175
self.hyperparameter_definitions = self._parse_hyperparameters()
176176

177-
self.hyperparam_dict = {}
177+
self._hyperparameters = {}
178178
if hyperparameters:
179179
self.set_hyperparameters(**hyperparameters)
180180

@@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
215215
"""Placeholder docstring"""
216216
for k, v in kwargs.items():
217217
value = self._validate_and_cast_hyperparameter(k, v)
218-
self.hyperparam_dict[k] = value
218+
self._hyperparameters[k] = value
219219

220220
self._validate_and_set_default_hyperparameters()
221221

@@ -225,7 +225,7 @@ def hyperparameters(self):
225225
The fit() method, that does the model training, calls this method to
226226
find the hyperparameters you specified.
227227
"""
228-
return self.hyperparam_dict
228+
return self._hyperparameters
229229

230230
def training_image_uri(self):
231231
"""Returns the docker image to use for training.
@@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
464464
# Check if all the required hyperparameters are set. If there is a default value
465465
# for one, set it.
466466
for name, definition in self.hyperparameter_definitions.items():
467-
if name not in self.hyperparam_dict:
467+
if name not in self._hyperparameters:
468468
spec = definition["spec"]
469469
if "DefaultValue" in spec:
470-
self.hyperparam_dict[name] = spec["DefaultValue"]
470+
self._hyperparameters[name] = spec["DefaultValue"]
471471
elif "IsRequired" in spec and spec["IsRequired"]:
472472
raise ValueError("Required hyperparameter: %s is not set" % name)
473473

src/sagemaker/chainer/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717

18-
from sagemaker.estimator import Framework
18+
from sagemaker.estimator import Framework, EstimatorBase
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
@@ -158,7 +158,9 @@ def hyperparameters(self):
158158

159159
# remove unset keys.
160160
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
161-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
161+
hyperparameters.update(
162+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
163+
)
162164
return hyperparameters
163165

164166
def create_model(

src/sagemaker/estimator.py

Lines changed: 409 additions & 102 deletions
Large diffs are not rendered by default.

src/sagemaker/huggingface/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import re
1818

1919
from sagemaker.deprecations import renamed_kwargs
20-
from sagemaker.estimator import Framework
20+
from sagemaker.estimator import Framework, EstimatorBase
2121
from sagemaker.fw_utils import (
2222
framework_name_from_image,
2323
warn_if_parameter_server_with_multi_gpu,
@@ -246,13 +246,13 @@ def hyperparameters(self):
246246
distribution=self.distribution
247247
)
248248
hyperparameters.update(
249-
Framework._json_encode_hyperparameters(distributed_training_hyperparameters)
249+
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
250250
)
251251

252252
if self.compiler_config:
253253
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
254254
hyperparameters.update(
255-
Framework._json_encode_hyperparameters(training_compiler_hyperparameters)
255+
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
256256
)
257257

258258
return hyperparameters

src/sagemaker/model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,10 @@ def prepare_container_def(
399399
)
400400
deploy_env = copy.deepcopy(self.env)
401401
if self.source_dir or self.dependencies or self.entry_point or self.git_config:
402-
if self.key_prefix or self.git_config:
403-
self._upload_code(deploy_key_prefix, repack=False)
404-
elif self.source_dir and self.entry_point:
405-
self._upload_code(deploy_key_prefix, repack=True)
406-
else:
407-
self._upload_code(deploy_key_prefix, repack=False)
402+
is_repack = (
403+
self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
404+
)
405+
self._upload_code(deploy_key_prefix, repack=is_repack)
408406
deploy_env.update(self._script_mode_env_vars())
409407
return sagemaker.container_def(
410408
self.image_uri, self.model_data, deploy_env, image_config=self.image_config

src/sagemaker/pytorch/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from packaging.version import Version
1919

2020
from sagemaker.deprecations import renamed_kwargs
21-
from sagemaker.estimator import Framework
21+
from sagemaker.estimator import Framework, EstimatorBase
2222
from sagemaker.fw_utils import (
2323
framework_name_from_image,
2424
framework_version_from_tag,
@@ -192,7 +192,9 @@ def hyperparameters(self):
192192
additional_hyperparameters = self._distribution_configuration(
193193
distribution=self.distribution
194194
)
195-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
195+
hyperparameters.update(
196+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
197+
)
196198
return hyperparameters
197199

198200
def create_model(

src/sagemaker/rl/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import re
1919

2020
from sagemaker import image_uris, fw_utils
21-
from sagemaker.estimator import Framework
21+
from sagemaker.estimator import Framework, EstimatorBase
2222
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
2323
from sagemaker.mxnet.model import MXNetModel
2424
from sagemaker.tensorflow.model import TensorFlowModel
@@ -340,7 +340,9 @@ def hyperparameters(self):
340340
SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE,
341341
}
342342

343-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
343+
hyperparameters.update(
344+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
345+
)
344346
return hyperparameters
345347

346348
@classmethod

src/sagemaker/tensorflow/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker import image_uris, s3, utils
2121
from sagemaker.deprecations import renamed_kwargs
22-
from sagemaker.estimator import Framework
22+
from sagemaker.estimator import Framework, EstimatorBase
2323
import sagemaker.fw_utils as fw
2424
from sagemaker.tensorflow import defaults
2525
from sagemaker.tensorflow.model import TensorFlowModel
@@ -327,7 +327,9 @@ def hyperparameters(self):
327327
)
328328
additional_hyperparameters["model_dir"] = self.model_dir
329329

330-
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
330+
hyperparameters.update(
331+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
332+
)
331333
return hyperparameters
332334

333335
def _default_s3_path(self, directory, mpi=False):

tests/unit/test_estimator.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import os
1818
import subprocess
1919
from time import sleep
20+
from sagemaker.fw_utils import UploadedCode
21+
2022

2123
import pytest
2224
from botocore.exceptions import ClientError
@@ -3350,3 +3352,112 @@ def test_image_name_map(sagemaker_session):
33503352
)
33513353

33523354
assert e.image_uri == IMAGE_URI
3355+
3356+
3357+
@patch("sagemaker.git_utils.git_clone_repo")
3358+
def test_git_support_with_branch_and_commit_succeed_estimator_class(
3359+
git_clone_repo, sagemaker_session
3360+
):
3361+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, source_dir=None, dependencies=None: {
3362+
"entry_point": "/tmp/repo_dir/entry_point",
3363+
"source_dir": None,
3364+
"dependencies": None,
3365+
}
3366+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
3367+
entry_point = "entry_point"
3368+
fw = Estimator(
3369+
entry_point=entry_point,
3370+
git_config=git_config,
3371+
role=ROLE,
3372+
sagemaker_session=sagemaker_session,
3373+
instance_count=INSTANCE_COUNT,
3374+
instance_type=INSTANCE_TYPE,
3375+
image_uri=IMAGE_URI,
3376+
)
3377+
fw.fit()
3378+
git_clone_repo.assert_called_once_with(git_config, entry_point, None, None)
3379+
3380+
3381+
@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3")
3382+
def test_script_mode_estimator(patched_stage_user_code, sagemaker_session):
3383+
patched_stage_user_code.return_value = UploadedCode(
3384+
s3_prefix="s3://bucket/key", script_name="script_name"
3385+
)
3386+
script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"
3387+
image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38"
3388+
model_uri = "s3://someprefix2/models/model.tar.gz"
3389+
t = Estimator(
3390+
entry_point=SCRIPT_PATH,
3391+
role=ROLE,
3392+
sagemaker_session=sagemaker_session,
3393+
instance_count=INSTANCE_COUNT,
3394+
instance_type=INSTANCE_TYPE,
3395+
source_dir=script_uri,
3396+
image_uri=image_uri,
3397+
model_uri=model_uri,
3398+
)
3399+
t.fit("s3://bucket/mydata")
3400+
3401+
patched_stage_user_code.assert_called_once()
3402+
sagemaker_session.train.assert_called_once()
3403+
3404+
3405+
@patch("time.time", return_value=TIME)
3406+
@patch("sagemaker.estimator.tar_and_upload_dir")
3407+
def test_script_mode_estimator_same_calls_as_framework(
3408+
patched_tar_and_upload_dir, sagemaker_session
3409+
):
3410+
3411+
patched_tar_and_upload_dir.return_value = UploadedCode(
3412+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
3413+
)
3414+
sagemaker_session.boto_region_name = REGION
3415+
3416+
script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"
3417+
3418+
instance_type = "ml.p2.xlarge"
3419+
instance_count = 1
3420+
3421+
model_uri = "s3://someprefix2/models/model.tar.gz"
3422+
training_data_uri = "s3://bucket/mydata"
3423+
3424+
generic_estimator = Estimator(
3425+
entry_point=SCRIPT_PATH,
3426+
role=ROLE,
3427+
region=REGION,
3428+
sagemaker_session=sagemaker_session,
3429+
instance_count=instance_count,
3430+
instance_type=instance_type,
3431+
source_dir=script_uri,
3432+
image_uri=IMAGE_URI,
3433+
model_uri=model_uri,
3434+
environment={"USE_SMDEBUG": "0"},
3435+
dependencies=[],
3436+
debugger_hook_config={},
3437+
)
3438+
generic_estimator.fit(training_data_uri)
3439+
3440+
generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list
3441+
generic_estimator_train_args = sagemaker_session.train.call_args_list
3442+
3443+
patched_tar_and_upload_dir.reset_mock()
3444+
sagemaker_session.train.reset_mock()
3445+
3446+
framework_estimator = DummyFramework(
3447+
entry_point=SCRIPT_PATH,
3448+
role=ROLE,
3449+
region=REGION,
3450+
source_dir=script_uri,
3451+
instance_count=instance_count,
3452+
instance_type=instance_type,
3453+
sagemaker_session=sagemaker_session,
3454+
model_uri=model_uri,
3455+
dependencies=[],
3456+
debugger_hook_config={},
3457+
)
3458+
framework_estimator.fit(training_data_uri)
3459+
3460+
assert len(generic_estimator_tar_and_upload_dir_args) == 1
3461+
assert len(generic_estimator_train_args) == 1
3462+
assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list
3463+
assert generic_estimator_train_args == sagemaker_session.train.call_args_list

0 commit comments

Comments
 (0)