@@ -107,8 +107,8 @@ def __init__(
107
107
"""
108
108
self ._default_bucket = None
109
109
self ._default_bucket_name_override = default_bucket
110
-
111
- # currently is used for local_code in local mode
110
+ self . s3_resource = None
111
+ self . s3_client = None
112
112
self .config = None
113
113
114
114
self ._initialize (
@@ -199,7 +199,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
199
199
key_suffix = name
200
200
201
201
bucket = bucket or self .default_bucket ()
202
- s3 = self .boto_session .resource ("s3" )
202
+ if self .s3_resource is None :
203
+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
204
+ else :
205
+ s3 = self .s3_resource
203
206
204
207
for local_path , s3_key in files :
205
208
s3 .Object (bucket , s3_key ).upload_file (local_path , ExtraArgs = extra_args )
@@ -227,7 +230,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
227
230
str: The S3 URI of the uploaded file.
228
231
The URI format is: ``s3://{bucket name}/{key}``.
229
232
"""
230
- s3 = self .boto_session .resource ("s3" )
233
+ if self .s3_resource is None :
234
+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
235
+ else :
236
+ s3 = self .s3_resource
237
+
231
238
s3_object = s3 .Object (bucket_name = bucket , key = key )
232
239
233
240
if kms_key is not None :
@@ -254,7 +261,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
254
261
255
262
"""
256
263
# Initialize the S3 client.
257
- s3 = self .boto_session .client ("s3" )
264
+ if self .s3_client is None :
265
+ s3 = self .boto_session .client ("s3" , region_name = self .boto_region_name )
266
+ else :
267
+ s3 = self .s3_client
258
268
259
269
# Initialize the variables used to loop through the contents of the S3 bucket.
260
270
keys = []
@@ -299,7 +309,10 @@ def read_s3_file(self, bucket, key_prefix):
299
309
str: The body of the s3 file as a string.
300
310
301
311
"""
302
- s3 = self .boto_session .client ("s3" )
312
+ if self .s3_client is None :
313
+ s3 = self .boto_session .client ("s3" , region_name = self .boto_region_name )
314
+ else :
315
+ s3 = self .s3_client
303
316
304
317
# Explicitly passing a None kms_key to boto3 throws a validation error.
305
318
s3_object = s3 .get_object (Bucket = bucket , Key = key_prefix )
@@ -317,7 +330,10 @@ def list_s3_files(self, bucket, key_prefix):
317
330
[str]: The list of files at the S3 path.
318
331
319
332
"""
320
- s3 = self .boto_session .resource ("s3" )
333
+ if self .s3_resource is None :
334
+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
335
+ else :
336
+ s3 = self .s3_resource
321
337
322
338
s3_bucket = s3 .Bucket (name = bucket )
323
339
s3_objects = s3_bucket .objects .filter (Prefix = key_prefix ).all ()
@@ -330,6 +346,7 @@ def default_bucket(self):
330
346
str: The name of the default bucket, which is of the form:
331
347
``sagemaker-{region}-{AWS account ID}``.
332
348
"""
349
+
333
350
if self ._default_bucket :
334
351
return self ._default_bucket
335
352
@@ -364,10 +381,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
364
381
already being created, no exception is raised.
365
382
366
383
"""
367
- bucket = self .boto_session .resource ("s3" , region_name = region ).Bucket (name = bucket_name )
384
+ if self .s3_resource is None :
385
+ s3 = self .boto_session .resource ("s3" , region_name = region )
386
+ else :
387
+ s3 = self .s3_resource
388
+
389
+ bucket = s3 .Bucket (name = bucket_name )
368
390
if bucket .creation_date is None :
369
391
try :
370
- s3 = self .boto_session .resource ("s3" , region_name = region )
371
392
if region == "us-east-1" :
372
393
# 'us-east-1' cannot be specified because it is the default region:
373
394
# https://github.com/boto/boto3/issues/125
0 commit comments