Skip to content

Commit 4074e75

Browse files
committed
don't allow for absolute path entry_point with S3 source_dir
1 parent 450037b commit 4074e75

File tree

10 files changed

+29
-64
lines changed

10 files changed

+29
-64
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def create_model(
216216
return ChainerModel(
217217
self.model_data,
218218
role or self.role,
219-
entry_point or self.uploaded_code.script_name,
219+
entry_point or self._model_entry_point(),
220220
source_dir=(source_dir or self._model_source_dir()),
221221
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
222222
container_log_level=self.container_log_level,

src/sagemaker/estimator.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,17 +1755,25 @@ def _stage_user_code_in_s3(self):
17551755
)
17561756

17571757
def _model_source_dir(self):
1758-
"""Get the appropriate value to pass as source_dir to model constructor
1759-
on deploying
1758+
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
17601759
17611760
Returns:
1762-
str: Either a local or an S3 path pointing to the source_dir to be
1763-
used for code by the model to be deployed
1761+
str: Either a local or an S3 path pointing to the ``source_dir`` to be
1762+
used for code by the model to be deployed
17641763
"""
17651764
return (
17661765
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
17671766
)
17681767

1768+
def _model_entry_point(self):
1769+
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.
1770+
1771+
Returns:
1772+
str: The path to the entry point script. This can be either an absolute path or
1773+
a path relative to ``self._model_source_dir()``.
1774+
"""
1775+
return self.uploaded_code.script_name if self._model_source_dir() else self.entry_point
1776+
17691777
def hyperparameters(self):
17701778
"""Return the hyperparameters as a dictionary to use for training.
17711779

src/sagemaker/mxnet/estimator.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,10 @@ def create_model(
217217
if "name" not in kwargs:
218218
kwargs["name"] = self._current_job_name
219219

220-
model = MXNetModel(
220+
return MXNetModel(
221221
self.model_data,
222222
role or self.role,
223-
entry_point,
223+
entry_point or self._model_entry_point(),
224224
source_dir=(source_dir or self._model_source_dir()),
225225
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
226226
container_log_level=self.container_log_level,
@@ -234,13 +234,6 @@ def create_model(
234234
**kwargs
235235
)
236236

237-
if entry_point is None:
238-
model.entry_point = (
239-
self.entry_point if model._is_mms_version() else self.uploaded_code.script_name
240-
)
241-
242-
return model
243-
244237
@classmethod
245238
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
246239
"""Convert the job description to init params that can be handled by the

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def create_model(
176176
return PyTorchModel(
177177
self.model_data,
178178
role or self.role,
179-
entry_point or self.uploaded_code.script_name,
179+
entry_point or self._model_entry_point(),
180180
source_dir=(source_dir or self._model_source_dir()),
181181
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
182182
container_log_level=self.container_log_level,

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def create_model(
229229
if not entry_point and (source_dir or dependencies):
230230
raise AttributeError("Please provide an `entry_point`.")
231231

232-
entry_point = entry_point or self.uploaded_code.script_name
232+
entry_point = entry_point or self._model_entry_point()
233233
source_dir = source_dir or self._model_source_dir()
234234
dependencies = dependencies or self.dependencies
235235

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def create_model(
194194
return SKLearnModel(
195195
self.model_data,
196196
role,
197-
entry_point or self.uploaded_code.script_name,
197+
entry_point or self._model_entry_point(),
198198
source_dir=(source_dir or self._model_source_dir()),
199199
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
200200
container_log_level=self.container_log_level,

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def _create_default_model(
660660
return TensorFlowModel(
661661
self.model_data,
662662
role,
663-
entry_point or self.uploaded_code.script_name,
663+
entry_point or self._model_entry_point(),
664664
source_dir=source_dir or self._model_source_dir(),
665665
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
666666
env={"SAGEMAKER_REQUIREMENTS": self.requirements_file},

src/sagemaker/xgboost/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def create_model(
174174
return XGBoostModel(
175175
self.model_data,
176176
role,
177-
entry_point or self.uploaded_code.script_name,
177+
entry_point or self._model_entry_point(),
178178
framework_version=self.framework_version,
179179
source_dir=(source_dir or self._model_source_dir()),
180180
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def mxnet_training_job(sagemaker_session, mxnet_full_version, cpu_instance_type)
3939
)
4040

4141
mx = MXNet(
42-
entry_point=os.path.join("mxnet_mnist", "mnist.py"),
42+
entry_point="mxnet_mnist/mnist.py",
4343
source_dir=s3_source,
4444
role="SageMakerRole",
4545
framework_version=mxnet_full_version,

tests/unit/test_mxnet.py

Lines changed: 8 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,15 @@ def test_create_model(sagemaker_session, mxnet_version):
180180
container_log_level = '"logging.INFO"'
181181
source_dir = "s3://mybucket/source"
182182
mx = MXNet(
183-
entry_point=SCRIPT_PATH,
183+
entry_point=SCRIPT_NAME,
184+
source_dir=source_dir,
184185
role=ROLE,
185186
sagemaker_session=sagemaker_session,
186187
train_instance_count=INSTANCE_COUNT,
187188
train_instance_type=INSTANCE_TYPE,
188189
framework_version=mxnet_version,
189190
container_log_level=container_log_level,
190191
base_job_name="job",
191-
source_dir=source_dir,
192192
)
193193

194194
job_name = "new_name"
@@ -198,6 +198,7 @@ def test_create_model(sagemaker_session, mxnet_version):
198198
assert model.sagemaker_session == sagemaker_session
199199
assert model.framework_version == mxnet_version
200200
assert model.py_version == mx.py_version
201+
assert model.entry_point == SCRIPT_NAME
201202
assert model.role == ROLE
202203
assert model.name == job_name
203204
assert model.container_log_level == container_log_level
@@ -206,55 +207,19 @@ def test_create_model(sagemaker_session, mxnet_version):
206207
assert model.vpc_config is None
207208

208209

209-
@patch("sagemaker.utils.create_tar_file", MagicMock())
210-
def test_create_model_default_entry_with_mms(
211-
sagemaker_session, mxnet_version, skip_if_not_mms_version
212-
):
213-
mx = MXNet(
214-
entry_point=SCRIPT_PATH,
215-
role=ROLE,
216-
sagemaker_session=sagemaker_session,
217-
train_instance_count=INSTANCE_COUNT,
218-
train_instance_type=INSTANCE_TYPE,
219-
framework_version=mxnet_version,
220-
)
221-
222-
mx.fit()
223-
model = mx.create_model()
224-
225-
assert model.entry_point == SCRIPT_PATH
226-
227-
228-
@patch("sagemaker.utils.create_tar_file", MagicMock())
229-
def test_create_model_default_entry_not_mms(sagemaker_session, mxnet_version, skip_if_mms_version):
230-
mx = MXNet(
231-
entry_point=SCRIPT_PATH,
232-
role=ROLE,
233-
sagemaker_session=sagemaker_session,
234-
train_instance_count=INSTANCE_COUNT,
235-
train_instance_type=INSTANCE_TYPE,
236-
framework_version=mxnet_version,
237-
)
238-
239-
mx.fit()
240-
model = mx.create_model()
241-
242-
assert model.entry_point == SCRIPT_NAME
243-
244-
245210
def test_create_model_with_optional_params(sagemaker_session):
246211
container_log_level = '"logging.INFO"'
247212
source_dir = "s3://mybucket/source"
248213
enable_cloudwatch_metrics = "true"
249214
mx = MXNet(
250-
entry_point=SCRIPT_PATH,
215+
entry_point=SCRIPT_NAME,
216+
source_dir=source_dir,
251217
role=ROLE,
252218
sagemaker_session=sagemaker_session,
253219
train_instance_count=INSTANCE_COUNT,
254220
train_instance_type=INSTANCE_TYPE,
255221
container_log_level=container_log_level,
256222
base_job_name="job",
257-
source_dir=source_dir,
258223
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
259224
)
260225

@@ -286,15 +251,15 @@ def test_create_model_with_custom_image(sagemaker_session):
286251
source_dir = "s3://mybucket/source"
287252
custom_image = "mxnet:2.0"
288253
mx = MXNet(
289-
entry_point=SCRIPT_PATH,
254+
entry_point=SCRIPT_NAME,
255+
source_dir=source_dir,
290256
role=ROLE,
291257
sagemaker_session=sagemaker_session,
292258
train_instance_count=INSTANCE_COUNT,
293259
train_instance_type=INSTANCE_TYPE,
294260
image_name=custom_image,
295261
container_log_level=container_log_level,
296262
base_job_name="job",
297-
source_dir=source_dir,
298263
)
299264

300265
job_name = "new_name"
@@ -303,7 +268,7 @@ def test_create_model_with_custom_image(sagemaker_session):
303268

304269
assert model.sagemaker_session == sagemaker_session
305270
assert model.image == custom_image
306-
assert model.entry_point == SCRIPT_PATH
271+
assert model.entry_point == SCRIPT_NAME
307272
assert model.role == ROLE
308273
assert model.name == job_name
309274
assert model.container_log_level == container_log_level
@@ -823,7 +788,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
823788
image_name=custom_image,
824789
container_log_level=container_log_level,
825790
base_job_name="job",
826-
source_dir=source_dir,
827791
)
828792

829793
mx.fit(inputs="s3://mybucket/train", job_name="new_name")

0 commit comments

Comments
 (0)