Skip to content

Commit 80b5e18

Browse files
committed
fix: correct S3Downloader behavior (#292)
1 parent 15793e2 commit 80b5e18

File tree

2 files changed

+170
-21
lines changed

2 files changed

+170
-21
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,11 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
239239

240240
# Initialize the variables used to loop through the contents of the S3 bucket.
241241
keys = []
242-
directories = []
243242
next_token = ""
244243
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
245244

246-
# Loop through the contents of the bucket, 1,000 objects at a time.
245+
# Loop through the contents of the bucket, 1,000 objects at a time, gathering all keys into
246+
# a "keys" list.
247247
while next_token is not None:
248248
request_parameters = base_parameters.copy()
249249
if next_token != "":
@@ -253,26 +253,20 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
253253
# For each object, save its key.
254254
for s3_object in contents:
255255
key = s3_object.get("Key")
256-
if key[-1] != "/":
257-
keys.append(key)
258-
else:
259-
directories.append(key)
256+
keys.append(key)
260257
next_token = response.get("NextContinuationToken")
261258

262-
# For each directory, create the directory on the local machine.
263-
for directory in directories:
264-
destination_path = os.path.join(path, directory)
265-
if not os.path.exists(os.path.dirname(destination_path)):
266-
os.makedirs(os.path.dirname(destination_path))
267-
268-
# For each object key, create the directory on the local machine,
269-
# and then download the file.
259+
# For each object key, create the directory on the local machine if needed, and then
260+
# download the file.
270261
for key in keys:
271-
destination_path = os.path.join(path, key)
262+
tail_s3_uri_path = os.path.basename(key_prefix)
263+
if not os.path.splitext(key_prefix)[1]:
264+
tail_s3_uri_path = os.path.relpath(key, key_prefix)
265+
destination_path = os.path.join(path, tail_s3_uri_path)
272266
if not os.path.exists(os.path.dirname(destination_path)):
273267
os.makedirs(os.path.dirname(destination_path))
274268
s3.download_file(
275-
bucket=bucket, key=key, filename=destination_path, extra_args=extra_args
269+
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
276270
)
277271

278272
def read_s3_file(self, bucket, key_prefix):

tests/integ/test_s3.py

Lines changed: 160 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,108 @@
2323
from tests.integ.kms_utils import get_or_create_kms_key
2424

2525

26+
TMP_BASE_PATH = "/tmp"
27+
28+
2629
@pytest.fixture(scope="module")
2730
def s3_files_kms_key(sagemaker_session):
2831
return get_or_create_kms_key(sagemaker_session=sagemaker_session)
2932

3033

31-
def test_statistics_object_creation_from_s3_uri_with_customizations(
34+
def test_s3_uploader_and_downloader_reads_files_when_given_file_name_uris(
3235
sagemaker_session, s3_files_kms_key
3336
):
34-
file_1_body = "First File Body."
35-
file_1_name = "first_file.txt"
36-
file_2_body = "Second File Body."
37-
file_2_name = "second_file.txt"
37+
my_uuid = str(uuid.uuid4())
38+
39+
file_1_body = "First File Body {}.".format(my_uuid)
40+
file_1_name = "first_file_{}.txt".format(my_uuid)
41+
file_2_body = "Second File Body {}.".format(my_uuid)
42+
file_2_name = "second_file_{}.txt".format(my_uuid)
43+
44+
base_s3_uri = os.path.join(
45+
"s3://", sagemaker_session.default_bucket(), "integ-test-test-s3-list", my_uuid
46+
)
47+
file_1_s3_uri = os.path.join(base_s3_uri, file_1_name)
48+
file_2_s3_uri = os.path.join(base_s3_uri, file_2_name)
49+
50+
S3Uploader.upload_string_as_file_body(
51+
body=file_1_body,
52+
desired_s3_uri=file_1_s3_uri,
53+
kms_key=s3_files_kms_key,
54+
session=sagemaker_session,
55+
)
56+
57+
S3Uploader.upload_string_as_file_body(
58+
body=file_2_body,
59+
desired_s3_uri=file_2_s3_uri,
60+
kms_key=s3_files_kms_key,
61+
session=sagemaker_session,
62+
)
63+
64+
s3_uris = S3Downloader.list(s3_uri=base_s3_uri, session=sagemaker_session)
65+
66+
assert file_1_name in s3_uris[0]
67+
assert file_2_name in s3_uris[1]
68+
69+
assert file_1_body == S3Downloader.read_file(s3_uri=s3_uris[0], session=sagemaker_session)
70+
assert file_2_body == S3Downloader.read_file(s3_uri=s3_uris[1], session=sagemaker_session)
71+
72+
73+
def test_s3_uploader_and_downloader_downloads_files_when_given_file_name_uris(
74+
sagemaker_session, s3_files_kms_key
75+
):
76+
my_uuid = str(uuid.uuid4())
77+
78+
file_1_body = "First File Body {}.".format(my_uuid)
79+
file_1_name = "first_file_{}.txt".format(my_uuid)
80+
file_2_body = "Second File Body {}.".format(my_uuid)
81+
file_2_name = "second_file_{}.txt".format(my_uuid)
82+
83+
base_s3_uri = os.path.join(
84+
"s3://", sagemaker_session.default_bucket(), "integ-test-test-s3-list", my_uuid
85+
)
86+
file_1_s3_uri = os.path.join(base_s3_uri, file_1_name)
87+
file_2_s3_uri = os.path.join(base_s3_uri, file_2_name)
88+
89+
S3Uploader.upload_string_as_file_body(
90+
body=file_1_body,
91+
desired_s3_uri=file_1_s3_uri,
92+
kms_key=s3_files_kms_key,
93+
session=sagemaker_session,
94+
)
95+
96+
S3Uploader.upload_string_as_file_body(
97+
body=file_2_body,
98+
desired_s3_uri=file_2_s3_uri,
99+
kms_key=s3_files_kms_key,
100+
session=sagemaker_session,
101+
)
102+
103+
s3_uris = S3Downloader.list(s3_uri=base_s3_uri, session=sagemaker_session)
104+
105+
assert file_1_name in s3_uris[0]
106+
assert file_2_name in s3_uris[1]
107+
108+
S3Downloader.download(s3_uri=s3_uris[0], local_path=TMP_BASE_PATH, session=sagemaker_session)
109+
S3Downloader.download(s3_uri=s3_uris[1], local_path=TMP_BASE_PATH, session=sagemaker_session)
110+
111+
with open(os.path.join(TMP_BASE_PATH, file_1_name), "r") as f:
112+
assert file_1_body == f.read()
38113

114+
with open(os.path.join(TMP_BASE_PATH, file_2_name), "r") as f:
115+
assert file_2_body == f.read()
116+
117+
118+
def test_s3_uploader_and_downloader_downloads_files_when_given_directory_uris_with_files(
119+
sagemaker_session, s3_files_kms_key
120+
):
39121
my_uuid = str(uuid.uuid4())
40122

123+
file_1_body = "First File Body {}.".format(my_uuid)
124+
file_1_name = "first_file_{}.txt".format(my_uuid)
125+
file_2_body = "Second File Body {}.".format(my_uuid)
126+
file_2_name = "second_file_{}.txt".format(my_uuid)
127+
41128
base_s3_uri = os.path.join(
42129
"s3://", sagemaker_session.default_bucket(), "integ-test-test-s3-list", my_uuid
43130
)
@@ -65,3 +152,71 @@ def test_statistics_object_creation_from_s3_uri_with_customizations(
65152

66153
assert file_1_body == S3Downloader.read_file(s3_uri=s3_uris[0], session=sagemaker_session)
67154
assert file_2_body == S3Downloader.read_file(s3_uri=s3_uris[1], session=sagemaker_session)
155+
156+
S3Downloader.download(s3_uri=base_s3_uri, local_path=TMP_BASE_PATH, session=sagemaker_session)
157+
158+
with open(os.path.join(TMP_BASE_PATH, file_1_name), "r") as f:
159+
assert file_1_body == f.read()
160+
161+
with open(os.path.join(TMP_BASE_PATH, file_2_name), "r") as f:
162+
assert file_2_body == f.read()
163+
164+
165+
def test_s3_uploader_and_downloader_downloads_files_when_given_directory_uris_with_directory(
166+
sagemaker_session, s3_files_kms_key
167+
):
168+
my_uuid = str(uuid.uuid4())
169+
my_inner_directory_uuid = str(uuid.uuid4())
170+
171+
file_1_body = "First File Body {}.".format(my_uuid)
172+
file_1_name = "first_file_{}.txt".format(my_uuid)
173+
file_2_body = "Second File Body {}.".format(my_uuid)
174+
file_2_name = "second_file_{}.txt".format(my_uuid)
175+
176+
base_s3_uri = os.path.join(
177+
"s3://",
178+
sagemaker_session.default_bucket(),
179+
"integ-test-test-s3-list",
180+
my_uuid,
181+
my_inner_directory_uuid,
182+
)
183+
file_1_s3_uri = os.path.join(base_s3_uri, file_1_name)
184+
file_2_s3_uri = os.path.join(base_s3_uri, file_2_name)
185+
186+
S3Uploader.upload_string_as_file_body(
187+
body=file_1_body,
188+
desired_s3_uri=file_1_s3_uri,
189+
kms_key=s3_files_kms_key,
190+
session=sagemaker_session,
191+
)
192+
193+
S3Uploader.upload_string_as_file_body(
194+
body=file_2_body,
195+
desired_s3_uri=file_2_s3_uri,
196+
kms_key=s3_files_kms_key,
197+
session=sagemaker_session,
198+
)
199+
200+
s3_uris = S3Downloader.list(s3_uri=base_s3_uri, session=sagemaker_session)
201+
202+
assert file_1_name in s3_uris[0]
203+
assert file_2_name in s3_uris[1]
204+
205+
assert file_1_body == S3Downloader.read_file(s3_uri=s3_uris[0], session=sagemaker_session)
206+
assert file_2_body == S3Downloader.read_file(s3_uri=s3_uris[1], session=sagemaker_session)
207+
208+
s3_directory_with_directory_underneath = os.path.join(
209+
"s3://", sagemaker_session.default_bucket(), "integ-test-test-s3-list", my_uuid
210+
)
211+
212+
S3Downloader.download(
213+
s3_uri=s3_directory_with_directory_underneath,
214+
local_path=TMP_BASE_PATH,
215+
session=sagemaker_session,
216+
)
217+
218+
with open(os.path.join(TMP_BASE_PATH, my_inner_directory_uuid, file_1_name), "r") as f:
219+
assert file_1_body == f.read()
220+
221+
with open(os.path.join(TMP_BASE_PATH, my_inner_directory_uuid, file_2_name), "r") as f:
222+
assert file_2_body == f.read()

0 commit comments

Comments
 (0)