Skip to content

Commit 211f4e5

Browse files
authored
breaking: preserve script path when S3 source_dir is provided (#941)
1 parent db21a38 commit 211f4e5

File tree

13 files changed

+69
-29
lines changed

13 files changed

+69
-29
lines changed

src/sagemaker/chainer/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def create_model(
214214
return ChainerModel(
215215
self.model_data,
216216
role or self.role,
217-
entry_point or self.entry_point,
217+
entry_point or self._model_entry_point(),
218218
source_dir=(source_dir or self._model_source_dir()),
219219
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
220220
container_log_level=self.container_log_level,

src/sagemaker/estimator.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -1734,17 +1734,28 @@ def _stage_user_code_in_s3(self):
17341734
)
17351735

17361736
def _model_source_dir(self):
1737-
"""Get the appropriate value to pass as source_dir to model constructor
1738-
on deploying
1737+
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
17391738
17401739
Returns:
1741-
str: Either a local or an S3 path pointing to the source_dir to be
1742-
used for code by the model to be deployed
1740+
str: Either a local or an S3 path pointing to the ``source_dir`` to be
1741+
used for code by the model to be deployed
17431742
"""
17441743
return (
17451744
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
17461745
)
17471746

1747+
def _model_entry_point(self):
1748+
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.
1749+
1750+
Returns:
1751+
str: The path to the entry point script. This can be either an absolute path or
1752+
a path relative to ``self._model_source_dir()``.
1753+
"""
1754+
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
1755+
return self.entry_point
1756+
1757+
return self.uploaded_code.script_name
1758+
17481759
def hyperparameters(self):
17491760
"""Return the hyperparameters as a dictionary to use for training.
17501761

src/sagemaker/fw_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def tar_and_upload_dir(
447447
script name.
448448
"""
449449
if directory and directory.lower().startswith("s3://"):
450-
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
450+
return UploadedCode(s3_prefix=directory, script_name=script)
451451

452452
script_name = script if directory else os.path.basename(script)
453453
dependencies = dependencies or []

src/sagemaker/mxnet/estimator.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ def create_model(
218218

219219
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
220220

221-
return MXNetModel(
221+
model = MXNetModel(
222222
self.model_data,
223223
role or self.role,
224-
entry_point or self.entry_point,
224+
entry_point,
225225
framework_version=self.framework_version,
226226
py_version=self.py_version,
227227
source_dir=(source_dir or self._model_source_dir()),
@@ -235,6 +235,13 @@ def create_model(
235235
**kwargs
236236
)
237237

238+
if entry_point is None:
239+
model.entry_point = (
240+
self.entry_point if model._is_mms_version() else self._model_entry_point()
241+
)
242+
243+
return model
244+
238245
@classmethod
239246
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
240247
"""Convert the job description to init params that can be handled by the

src/sagemaker/pytorch/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def create_model(
175175
return PyTorchModel(
176176
self.model_data,
177177
role or self.role,
178-
entry_point or self.entry_point,
178+
entry_point or self._model_entry_point(),
179179
framework_version=self.framework_version,
180180
py_version=self.py_version,
181181
source_dir=(source_dir or self._model_source_dir()),

src/sagemaker/rl/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def create_model(
232232
if not entry_point and (source_dir or dependencies):
233233
raise AttributeError("Please provide an `entry_point`.")
234234

235-
entry_point = entry_point or self.entry_point
235+
entry_point = entry_point or self._model_entry_point()
236236
source_dir = source_dir or self._model_source_dir()
237237
dependencies = dependencies or self.dependencies
238238

src/sagemaker/sklearn/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def create_model(
196196
return SKLearnModel(
197197
self.model_data,
198198
role,
199-
entry_point or self.entry_point,
199+
entry_point or self._model_entry_point(),
200200
source_dir=(source_dir or self._model_source_dir()),
201201
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
202202
container_log_level=self.container_log_level,

src/sagemaker/xgboost/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_model(
172172
return XGBoostModel(
173173
self.model_data,
174174
role,
175-
entry_point or self.entry_point,
175+
entry_point or self._model_entry_point(),
176176
framework_version=self.framework_version,
177177
source_dir=(source_dir or self._model_source_dir()),
178178
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
2.15 KB
Binary file not shown.

tests/integ/test_mxnet.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ def mxnet_training_job(
3232
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
3333
):
3434
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
35-
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
35+
s3_prefix = "integ-test-data/mxnet_mnist"
3636
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
3737

38+
s3_source = sagemaker_session.upload_data(
39+
path=os.path.join(data_path, "sourcedir.tar.gz"), key_prefix="{}/src".format(s3_prefix)
40+
)
41+
3842
mx = MXNet(
39-
entry_point=script_path,
43+
entry_point="mxnet_mnist/mnist.py",
44+
source_dir=s3_source,
4045
role="SageMakerRole",
4146
framework_version=mxnet_full_version,
4247
py_version=mxnet_full_py_version,
@@ -46,10 +51,10 @@ def mxnet_training_job(
4651
)
4752

4853
train_input = mx.sagemaker_session.upload_data(
49-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
54+
path=os.path.join(data_path, "train"), key_prefix="{}/train".format(s3_prefix)
5055
)
5156
test_input = mx.sagemaker_session.upload_data(
52-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
57+
path=os.path.join(data_path, "test"), key_prefix="{}/test".format(s3_prefix)
5358
)
5459

5560
mx.fit({"train": train_input, "test": test_input})
@@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
6267

6368
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
6469
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
65-
predictor = estimator.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
70+
predictor = estimator.deploy(
71+
1,
72+
cpu_instance_type,
73+
entry_point="mnist.py",
74+
source_dir=os.path.join(DATA_DIR, "mxnet_mnist"),
75+
endpoint_name=endpoint_name,
76+
)
6677
data = numpy.zeros(shape=(1, 1, 28, 28))
6778
result = predictor.predict(data)
6879
assert result is not None

tests/unit/test_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1269,7 +1269,7 @@ def test_git_support_codecommit_with_ssh_no_passphrase_needed(git_clone_repo, sa
12691269
@patch("time.strftime", return_value=TIMESTAMP)
12701270
def test_init_with_source_dir_s3(strftime, sagemaker_session):
12711271
fw = DummyFramework(
1272-
entry_point=SCRIPT_PATH,
1272+
entry_point=SCRIPT_NAME,
12731273
source_dir="s3://location",
12741274
role=ROLE,
12751275
sagemaker_session=sagemaker_session,

tests/unit/test_fw_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,18 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
873873
assert result == fw_utils.UploadedCode("s3://m", "mnist.py")
874874

875875

876+
def test_tar_and_upload_dir_s3_with_script_dir(sagemaker_session):
877+
bucket = "mybucket"
878+
s3_key_prefix = "something/source"
879+
script = "some/dir/mnist.py"
880+
directory = "s3://m"
881+
result = fw_utils.tar_and_upload_dir(
882+
sagemaker_session, bucket, s3_key_prefix, script, directory
883+
)
884+
885+
assert result == fw_utils.UploadedCode("s3://m", "some/dir/mnist.py")
886+
887+
876888
@patch("sagemaker.utils")
877889
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
878890
bucket = "mybucket"

tests/unit/test_mxnet.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from sagemaker.mxnet import MXNetPredictor, MXNetModel
2828

2929
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
30-
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
30+
SCRIPT_NAME = "dummy_script.py"
31+
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME)
3132
SERVING_SCRIPT_FILE = "another_dummy_script.py"
3233
MODEL_DATA = "s3://mybucket/model"
3334
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
@@ -189,7 +190,8 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
189190
base_job_name = "job"
190191

191192
mx = MXNet(
192-
entry_point=SCRIPT_PATH,
193+
entry_point=SCRIPT_NAME,
194+
source_dir=source_dir,
193195
framework_version=mxnet_version,
194196
py_version=mxnet_py_version,
195197
role=ROLE,
@@ -198,7 +200,6 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
198200
instance_type=INSTANCE_TYPE,
199201
container_log_level=container_log_level,
200202
base_job_name=base_job_name,
201-
source_dir=source_dir,
202203
)
203204

204205
mx.fit(inputs="s3://mybucket/train", job_name="new_name")
@@ -210,7 +211,7 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
210211
assert model.sagemaker_session == sagemaker_session
211212
assert model.framework_version == mxnet_version
212213
assert model.py_version == mxnet_py_version
213-
assert model.entry_point == SCRIPT_PATH
214+
assert model.entry_point == SCRIPT_NAME
214215
assert model.role == ROLE
215216
assert model.name == model_name
216217
assert model.container_log_level == container_log_level
@@ -226,7 +227,8 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
226227
source_dir = "s3://mybucket/source"
227228
enable_cloudwatch_metrics = "true"
228229
mx = MXNet(
229-
entry_point=SCRIPT_PATH,
230+
entry_point=SCRIPT_NAME,
231+
source_dir=source_dir,
230232
framework_version=mxnet_version,
231233
py_version=mxnet_py_version,
232234
role=ROLE,
@@ -235,7 +237,6 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
235237
instance_type=INSTANCE_TYPE,
236238
container_log_level=container_log_level,
237239
base_job_name="job",
238-
source_dir=source_dir,
239240
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
240241
)
241242

@@ -270,7 +271,8 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
270271
base_job_name = "job"
271272

272273
mx = MXNet(
273-
entry_point=SCRIPT_PATH,
274+
entry_point=SCRIPT_NAME,
275+
source_dir=source_dir,
274276
framework_version="2.0",
275277
py_version="py3",
276278
role=ROLE,
@@ -280,7 +282,6 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
280282
image_uri=custom_image,
281283
container_log_level=container_log_level,
282284
base_job_name=base_job_name,
283-
source_dir=source_dir,
284285
)
285286

286287
mx.fit(inputs="s3://mybucket/train", job_name="new_name")
@@ -291,7 +292,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
291292

292293
assert model.sagemaker_session == sagemaker_session
293294
assert model.image_uri == custom_image
294-
assert model.entry_point == SCRIPT_PATH
295+
assert model.entry_point == SCRIPT_NAME
295296
assert model.role == ROLE
296297
assert model.name == model_name
297298
assert model.container_log_level == container_log_level
@@ -730,7 +731,6 @@ def test_model_py2_warning(warning, sagemaker_session):
730731

731732
def test_create_model_with_custom_hosting_image(sagemaker_session):
732733
container_log_level = '"logging.INFO"'
733-
source_dir = "s3://mybucket/source"
734734
custom_image = "mxnet:2.0"
735735
custom_hosting_image = "mxnet_hosting:2.0"
736736
mx = MXNet(
@@ -744,7 +744,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
744744
image_uri=custom_image,
745745
container_log_level=container_log_level,
746746
base_job_name="job",
747-
source_dir=source_dir,
748747
)
749748

750749
mx.fit(inputs="s3://mybucket/train", job_name="new_name")

0 commit comments

Comments
 (0)