Skip to content

Commit 5b3ab50

Browse files
laurenyujesterhazy
authored andcommitted
Add a check for S3 paths being incorrectly passed as an entry point (#500)
1 parent 13e4040 commit 5b3ab50

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
740740
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
741741
742742
Args:
743-
entry_point (str): Path (absolute or relative) to the Python source file which should be executed
743+
entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
744744
as the entry point to training. This should be compatible with either Python 2.7 or Python 3.5.
745745
source_dir (str): Path (absolute or relative) to a directory with any other training
746746
source code dependencies aside from tne entry point file (default: None). Structure within this
@@ -779,9 +779,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
779779
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
780780
"""
781781
super(Framework, self).__init__(**kwargs)
782+
if entry_point.startswith('s3://'):
783+
raise ValueError('Invalid entry point script: {}. Must be a path to a local file.'.format(entry_point))
784+
self.entry_point = entry_point
782785
self.source_dir = source_dir
783786
self.dependencies = dependencies or []
784-
self.entry_point = entry_point
785787
if enable_cloudwatch_metrics:
786788
warnings.warn('enable_cloudwatch_metrics is now deprecated and will be removed in the future.',
787789
DeprecationWarning)

tests/unit/test_estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,14 @@ def test_framework_all_init_args(sagemaker_session):
171171
'metric_definitions': [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}]}
172172

173173

174+
def test_framework_init_s3_entry_point_invalid(sagemaker_session):
175+
with pytest.raises(ValueError) as error:
176+
DummyFramework('s3://remote-script-because-im-mistaken', role=ROLE,
177+
sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT,
178+
train_instance_type=INSTANCE_TYPE)
179+
assert 'Must be a path to a local file' in str(error)
180+
181+
174182
def test_sagemaker_s3_uri_invalid(sagemaker_session):
175183
with pytest.raises(ValueError) as error:
176184
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)