Skip to content

breaking: preserve script path when S3 source_dir is provided #941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f0d4399
fix: preserve script path when S3 source_dir is provided
laurenyu Jul 17, 2019
c861ce7
add unit test
laurenyu Jul 17, 2019
7464a90
Merge branch 'master' into script-path
laurenyu Jul 23, 2019
2b3c149
fix estimator --> model entry point logic
laurenyu Jul 23, 2019
95a30f3
fix MXNet logic
laurenyu Jul 25, 2019
3b49b6d
add docstring and unit test
laurenyu Jul 25, 2019
6e16ab9
add integ test
laurenyu Jul 25, 2019
139d060
Merge branch 'master' into script-path
icywang86rui Aug 6, 2019
98818da
Merge branch 'master' into script-path
jesterhazy Aug 8, 2019
b0d1b11
testing
laurenyu Sep 5, 2019
d6cf3b5
testing
laurenyu Sep 10, 2019
564c8b1
try chainer
laurenyu Sep 10, 2019
3d75ac0
Merge branch 'master' into script-path
laurenyu Dec 10, 2019
921181e
black format
laurenyu Dec 10, 2019
ab438ca
Merge branch 'master' into script-path
laurenyu Dec 17, 2019
a7d7004
fix merge
laurenyu Dec 17, 2019
08512cf
Merge branch 'master' into script-path
laurenyu Jan 28, 2020
a9489e2
address pylint
laurenyu Jan 28, 2020
4d949e1
fix _is_mms_version() call
laurenyu Jan 29, 2020
f2fd196
Merge branch 'master' into script-path
laurenyu Jan 31, 2020
4e70f2e
Merge branch 'master' into script-path
knakad Feb 18, 2020
450037b
Merge branch 'master' into script-path
laurenyu Jun 4, 2020
4074e75
don't allow for absolute path entry_point with S3 source_dir
laurenyu Jun 4, 2020
36ec1a7
address flake8
laurenyu Jun 5, 2020
3933900
address local mode and MXNet + MMS
laurenyu Jun 5, 2020
a635631
add forgotten return statement
laurenyu Jun 9, 2020
98f4178
update test for new MMS usage (aka not supported)
laurenyu Jun 9, 2020
2b58b04
fix integ test
laurenyu Jun 10, 2020
55356dc
Merge branch 'zwei' into script-path
laurenyu Jul 15, 2020
fa3debc
Merge branch 'zwei' into script-path
laurenyu Jul 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def create_model(
return ChainerModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand Down
19 changes: 15 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,17 +1755,28 @@ def _stage_user_code_in_s3(self):
)

def _model_source_dir(self):
"""Get the appropriate value to pass as source_dir to model constructor
on deploying
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.

Returns:
str: Either a local or an S3 path pointing to the source_dir to be
used for code by the model to be deployed
str: Either a local or an S3 path pointing to the ``source_dir`` to be
used for code by the model to be deployed
"""
return (
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
)

def _model_entry_point(self):
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.

Returns:
str: The path to the entry point script. This can be either an absolute path or
a path relative to ``self._model_source_dir()``.
"""
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
return self.entry_point

return self.uploaded_code.script_name

def hyperparameters(self):
"""Return the hyperparameters as a dictionary to use for training.

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def tar_and_upload_dir(
script name.
"""
if directory and directory.lower().startswith("s3://"):
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
return UploadedCode(s3_prefix=directory, script_name=script)

script_name = script if directory else os.path.basename(script)
dependencies = dependencies or []
Expand Down
11 changes: 9 additions & 2 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def create_model(
if "name" not in kwargs:
kwargs["name"] = self._current_job_name

return MXNetModel(
model = MXNetModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand All @@ -234,6 +234,13 @@ def create_model(
**kwargs
)

if entry_point is None:
model.entry_point = (
self.entry_point if model._is_mms_version() else self._model_entry_point()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:(

)

return model

@classmethod
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
"""Convert the job description to init params that can be handled by the
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def create_model(
return PyTorchModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def create_model(
if not entry_point and (source_dir or dependencies):
raise AttributeError("Please provide an `entry_point`.")

entry_point = entry_point or self.entry_point
entry_point = entry_point or self._model_entry_point()
source_dir = source_dir or self._model_source_dir()
dependencies = dependencies or self.dependencies

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def create_model(
return SKLearnModel(
self.model_data,
role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def _create_default_model(
return TensorFlowModel(
self.model_data,
role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=source_dir or self._model_source_dir(),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def create_model(
return XGBoostModel(
self.model_data,
role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
framework_version=self.framework_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
Expand Down
Binary file added tests/data/mxnet_mnist/sourcedir.tar.gz
Binary file not shown.
21 changes: 16 additions & 5 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@
@pytest.fixture(scope="module")
def mxnet_training_job(sagemaker_session, mxnet_full_version, cpu_instance_type):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
s3_prefix = "integ-test-data/mxnet_mnist"
data_path = os.path.join(DATA_DIR, "mxnet_mnist")

s3_source = sagemaker_session.upload_data(
path=os.path.join(data_path, "sourcedir.tar.gz"), key_prefix="{}/src".format(s3_prefix)
)

mx = MXNet(
entry_point=script_path,
entry_point="mxnet_mnist/mnist.py",
source_dir=s3_source,
role="SageMakerRole",
framework_version=mxnet_full_version,
py_version=PYTHON_VERSION,
Expand All @@ -45,10 +50,10 @@ def mxnet_training_job(sagemaker_session, mxnet_full_version, cpu_instance_type)
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
path=os.path.join(data_path, "train"), key_prefix="{}/train".format(s3_prefix)
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
path=os.path.join(data_path, "test"), key_prefix="{}/test".format(s3_prefix)
)

mx.fit({"train": train_input, "test": test_input})
Expand All @@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
predictor = estimator.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor = estimator.deploy(
1,
cpu_instance_type,
entry_point="mnist.py",
source_dir=os.path.join(DATA_DIR, "mxnet_mnist"),
endpoint_name=endpoint_name,
)
data = numpy.zeros(shape=(1, 1, 28, 28))
result = predictor.predict(data)
assert result is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,7 +1292,7 @@ def test_git_support_codecommit_with_ssh_no_passphrase_needed(git_clone_repo, sa
@patch("time.strftime", return_value=TIMESTAMP)
def test_init_with_source_dir_s3(strftime, sagemaker_session):
fw = DummyFramework(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir="s3://location",
role=ROLE,
sagemaker_session=sagemaker_session,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,18 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
assert result == fw_utils.UploadedCode("s3://m", "mnist.py")


def test_tar_and_upload_dir_s3_with_script_dir(sagemaker_session):
bucket = "mybucket"
s3_key_prefix = "something/source"
script = "some/dir/mnist.py"
directory = "s3://m"
result = fw_utils.tar_and_upload_dir(
sagemaker_session, bucket, s3_key_prefix, script, directory
)

assert result == fw_utils.UploadedCode("s3://m", "some/dir/mnist.py")


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
bucket = "mybucket"
Expand Down
21 changes: 10 additions & 11 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.mxnet import MXNetPredictor, MXNetModel

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
SCRIPT_NAME = "dummy_script.py"
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME)
SERVING_SCRIPT_FILE = "another_dummy_script.py"
MODEL_DATA = "s3://mybucket/model"
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
Expand Down Expand Up @@ -179,15 +180,15 @@ def test_create_model(sagemaker_session, mxnet_version):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
framework_version=mxnet_version,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
)

job_name = "new_name"
Expand All @@ -197,7 +198,7 @@ def test_create_model(sagemaker_session, mxnet_version):
assert model.sagemaker_session == sagemaker_session
assert model.framework_version == mxnet_version
assert model.py_version == mx.py_version
assert model.entry_point == SCRIPT_PATH
assert model.entry_point == SCRIPT_NAME
assert model.role == ROLE
assert model.name == job_name
assert model.container_log_level == container_log_level
Expand All @@ -211,14 +212,14 @@ def test_create_model_with_optional_params(sagemaker_session):
source_dir = "s3://mybucket/source"
enable_cloudwatch_metrics = "true"
mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
)

Expand Down Expand Up @@ -250,15 +251,15 @@ def test_create_model_with_custom_image(sagemaker_session):
source_dir = "s3://mybucket/source"
custom_image = "mxnet:2.0"
mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_count=INSTANCE_COUNT,
train_instance_type=INSTANCE_TYPE,
image_name=custom_image,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
)

job_name = "new_name"
Expand All @@ -267,7 +268,7 @@ def test_create_model_with_custom_image(sagemaker_session):

assert model.sagemaker_session == sagemaker_session
assert model.image == custom_image
assert model.entry_point == SCRIPT_PATH
assert model.entry_point == SCRIPT_NAME
assert model.role == ROLE
assert model.name == job_name
assert model.container_log_level == container_log_level
Expand Down Expand Up @@ -775,7 +776,6 @@ def test_model_empty_framework_version(warning, sagemaker_session):

def test_create_model_with_custom_hosting_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
custom_image = "mxnet:2.0"
custom_hosting_image = "mxnet_hosting:2.0"
mx = MXNet(
Expand All @@ -787,7 +787,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
image_name=custom_image,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
)

mx.fit(inputs="s3://mybucket/train", job_name="new_name")
Expand Down