Skip to content

Commit 9b75787

Browse files
committed
Merge remote-tracking branch 'public/master' into pytorch-release
2 parents (ff596ea + c72121b), commit 9b75787

File tree

5 files changed: 33 additions and 3 deletions

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ CHANGELOG
77

88
* bug-fix: Unit Tests: Improve unit test runtime
99
* bug-fix: Estimators: Fix attach for LDA
10+
* bug-fix: Estimators: allow code_location to have no key prefix
11+
* bug-fix: Local Mode: Fix s3 training data download when there is a trailing slash
12+
1013

1114
1.4.1
1215
=====

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def _stage_user_code_in_s3(self):
571571
code_s3_prefix = '{}/source'.format(self._current_job_name)
572572
else:
573573
code_bucket, key_prefix = parse_s3_url(self.code_location)
574-
code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)
574+
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
575575

576576
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
577577
bucket=code_bucket,

src/sagemaker/local/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@ def _download_folder(self, bucket_name, prefix, target):
267267

268268
for obj_sum in bucket.objects.filter(Prefix=prefix):
269269
obj = s3.Object(obj_sum.bucket_name, obj_sum.key)
270-
file_path = os.path.join(target, obj_sum.key[len(prefix) + 1:])
270+
s3_relative_path = obj_sum.key[len(prefix):].lstrip('/')
271+
file_path = os.path.join(target, s3_relative_path)
271272

272273
try:
273274
os.makedirs(os.path.dirname(file_path))
274275
except OSError as exc:
275276
if exc.errno != errno.EEXIST:
276277
raise
277278
pass
278-
279279
obj.download_file(file_path)
280280

281281
def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):

tests/unit/test_estimator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,24 @@ def test_custom_code_bucket(time, sagemaker_session):
127127
assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir)
128128

129129

130+
@patch('time.strftime', return_value=TIMESTAMP)
131+
def test_custom_code_bucket_without_prefix(time, sagemaker_session):
132+
code_bucket = 'codebucket'
133+
code_location = 's3://{}'.format(code_bucket)
134+
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
135+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
136+
code_location=code_location)
137+
t.fit('s3://bucket/mydata')
138+
139+
expected_key = '{}/source/sourcedir.tar.gz'.format(JOB_NAME)
140+
_, s3_args, _ = sagemaker_session.boto_session.resource('s3').Object.mock_calls[0]
141+
assert s3_args == (code_bucket, expected_key)
142+
143+
expected_submit_dir = 's3://{}/{}'.format(code_bucket, expected_key)
144+
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
145+
assert train_kwargs['hyperparameters']['sagemaker_submit_directory'] == json.dumps(expected_submit_dir)
146+
147+
130148
def test_invalid_custom_code_bucket(sagemaker_session):
131149
code_location = 'thisllworkright?'
132150
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,

tests/unit/test_image.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,15 @@ def test_download_folder(makedirs):
366366
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
367367
call(os.path.join('/tmp', 'train/validation_data.csv'))]
368368
obj_mock.download_file.assert_has_calls(calls)
369+
obj_mock.reset_mock()
370+
371+
# Testing with a trailing slash for the prefix.
372+
sagemaker_container._download_folder(BUCKET_NAME, '/prefix/', '/tmp')
373+
obj_mock.download_file.assert_called()
374+
calls = [call(os.path.join('/tmp', 'train/train_data.csv')),
375+
call(os.path.join('/tmp', 'train/validation_data.csv'))]
376+
377+
obj_mock.download_file.assert_has_calls(calls)
369378

370379

371380
def test_ecr_login_non_ecr():

0 commit comments

Comments
 (0)