Skip to content

feat: Script mode support for Estimator class #2834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
self.validate_train_spec()
self.hyperparameter_definitions = self._parse_hyperparameters()

self.hyperparam_dict = {}
self._hyperparameters = {}
if hyperparameters:
self.set_hyperparameters(**hyperparameters)

Expand Down Expand Up @@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
"""Placeholder docstring"""
for k, v in kwargs.items():
value = self._validate_and_cast_hyperparameter(k, v)
self.hyperparam_dict[k] = value
self._hyperparameters[k] = value

self._validate_and_set_default_hyperparameters()

Expand All @@ -225,7 +225,7 @@ def hyperparameters(self):
The fit() method, that does the model training, calls this method to
find the hyperparameters you specified.
"""
return self.hyperparam_dict
return self._hyperparameters

def training_image_uri(self):
"""Returns the docker image to use for training.
Expand Down Expand Up @@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
# Check if all the required hyperparameters are set. If there is a default value
# for one, set it.
for name, definition in self.hyperparameter_definitions.items():
if name not in self.hyperparam_dict:
if name not in self._hyperparameters:
spec = definition["spec"]
if "DefaultValue" in spec:
self.hyperparam_dict[name] = spec["DefaultValue"]
self._hyperparameters[name] = spec["DefaultValue"]
elif "IsRequired" in spec and spec["IsRequired"]:
raise ValueError("Required hyperparameter: %s is not set" % name)

Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging

from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
Expand Down Expand Up @@ -158,7 +158,9 @@ def hyperparameters(self):

# remove unset keys.
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

def create_model(
Expand Down
511 changes: 409 additions & 102 deletions src/sagemaker/estimator.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re

from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
warn_if_parameter_server_with_multi_gpu,
Expand Down Expand Up @@ -246,13 +246,13 @@ def hyperparameters(self):
distribution=self.distribution
)
hyperparameters.update(
Framework._json_encode_hyperparameters(distributed_training_hyperparameters)
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
)

if self.compiler_config:
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
hyperparameters.update(
Framework._json_encode_hyperparameters(training_compiler_hyperparameters)
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
)

return hyperparameters
Expand Down
10 changes: 4 additions & 6 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,10 @@ def prepare_container_def(
)
deploy_env = copy.deepcopy(self.env)
if self.source_dir or self.dependencies or self.entry_point or self.git_config:
if self.key_prefix or self.git_config:
self._upload_code(deploy_key_prefix, repack=False)
elif self.source_dir and self.entry_point:
self._upload_code(deploy_key_prefix, repack=True)
else:
self._upload_code(deploy_key_prefix, repack=False)
is_repack = (
self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit for readability — I would vote for:

is_repack = self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
self._upload_code(deploy_key_prefix, repack=is_repack)

self._upload_code(deploy_key_prefix, repack=is_repack)
deploy_env.update(self._script_mode_env_vars())
return sagemaker.container_def(
self.image_uri, self.model_data, deploy_env, image_config=self.image_config
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from packaging.version import Version

from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
Expand Down Expand Up @@ -192,7 +192,9 @@ def hyperparameters(self):
additional_hyperparameters = self._distribution_configuration(
distribution=self.distribution
)
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

def create_model(
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re

from sagemaker import image_uris, fw_utils
from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.model import FrameworkModel, SAGEMAKER_OUTPUT_LOCATION
from sagemaker.mxnet.model import MXNetModel
from sagemaker.tensorflow.model import TensorFlowModel
Expand Down Expand Up @@ -340,7 +340,9 @@ def hyperparameters(self):
SAGEMAKER_ESTIMATOR: SAGEMAKER_ESTIMATOR_VALUE,
}

hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sagemaker import image_uris, s3, utils
from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
import sagemaker.fw_utils as fw
from sagemaker.tensorflow import defaults
from sagemaker.tensorflow.model import TensorFlowModel
Expand Down Expand Up @@ -327,7 +327,9 @@ def hyperparameters(self):
)
additional_hyperparameters["model_dir"] = self.model_dir

hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

def _default_s3_path(self, directory, mpi=False):
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import os
import subprocess
from time import sleep
from sagemaker.fw_utils import UploadedCode


import pytest
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -3350,3 +3352,112 @@ def test_image_name_map(sagemaker_session):
)

assert e.image_uri == IMAGE_URI


@patch("sagemaker.git_utils.git_clone_repo")
def test_git_support_with_branch_and_commit_succeed_estimator_class(
    git_clone_repo, sagemaker_session
):
    """A generic Estimator with git_config should clone the repo on fit().

    Verifies that git_clone_repo is invoked exactly once with the git config
    and entry point when both branch and commit are specified.
    """

    def fake_clone(gitconfig, entrypoint, source_dir=None, dependencies=None):
        # Simulate a successful clone that resolves the entry point locally.
        return {
            "entry_point": "/tmp/repo_dir/entry_point",
            "source_dir": None,
            "dependencies": None,
        }

    git_clone_repo.side_effect = fake_clone

    git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
    entry_point = "entry_point"
    estimator = Estimator(
        entry_point=entry_point,
        git_config=git_config,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        image_uri=IMAGE_URI,
    )
    estimator.fit()

    git_clone_repo.assert_called_once_with(git_config, entry_point, None, None)


@patch("sagemaker.estimator.Estimator._stage_user_code_in_s3")
def test_script_mode_estimator(patched_stage_user_code, sagemaker_session):
    """fit() on a script-mode Estimator stages user code and starts training.

    Checks that the code upload helper and session.train() are each called
    exactly once.
    """
    patched_stage_user_code.return_value = UploadedCode(
        s3_prefix="s3://bucket/key", script_name="script_name"
    )

    estimator = Estimator(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        source_dir="s3://codebucket/someprefix/sourcedir.tar.gz",
        image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38",
        model_uri="s3://someprefix2/models/model.tar.gz",
    )
    estimator.fit("s3://bucket/mydata")

    # One code upload, one training-job creation.
    patched_stage_user_code.assert_called_once()
    sagemaker_session.train.assert_called_once()


@patch("time.time", return_value=TIME)
@patch("sagemaker.estimator.tar_and_upload_dir")
def test_script_mode_estimator_same_calls_as_framework(
    patched_tar_and_upload_dir, sagemaker_session
):
    """A generic script-mode Estimator should make the same code-upload and
    train() calls as an equivalent Framework estimator.

    time.time is patched to a constant so any timestamp-derived values
    (e.g. job names / S3 keys) match between the two runs.
    """

    patched_tar_and_upload_dir.return_value = UploadedCode(
        s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
    )
    sagemaker_session.boto_region_name = REGION

    script_uri = "s3://codebucket/someprefix/sourcedir.tar.gz"

    instance_type = "ml.p2.xlarge"
    instance_count = 1

    model_uri = "s3://someprefix2/models/model.tar.gz"
    training_data_uri = "s3://bucket/mydata"

    generic_estimator = Estimator(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        region=REGION,
        sagemaker_session=sagemaker_session,
        instance_count=instance_count,
        instance_type=instance_type,
        source_dir=script_uri,
        image_uri=IMAGE_URI,
        model_uri=model_uri,
        environment={"USE_SMDEBUG": "0"},
        dependencies=[],
        debugger_hook_config={},
    )
    generic_estimator.fit(training_data_uri)

    # Capture the recorded calls BEFORE reset_mock().
    # NOTE(review): this relies on reset_mock() rebinding call_args_list to a
    # fresh list rather than clearing it in place, so these references retain
    # the generic estimator's calls — confirm against unittest.mock docs.
    generic_estimator_tar_and_upload_dir_args = patched_tar_and_upload_dir.call_args_list
    generic_estimator_train_args = sagemaker_session.train.call_args_list

    patched_tar_and_upload_dir.reset_mock()
    sagemaker_session.train.reset_mock()

    framework_estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        region=REGION,
        source_dir=script_uri,
        instance_count=instance_count,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        model_uri=model_uri,
        dependencies=[],
        debugger_hook_config={},
    )
    framework_estimator.fit(training_data_uri)

    # Each path uploaded code and started training exactly once, and the
    # generic estimator's calls match the framework estimator's calls
    # (call_args_list now holds only the framework run's calls).
    assert len(generic_estimator_tar_and_upload_dir_args) == 1
    assert len(generic_estimator_train_args) == 1
    assert generic_estimator_tar_and_upload_dir_args == patched_tar_and_upload_dir.call_args_list
    assert generic_estimator_train_args == sagemaker_session.train.call_args_list