Commit 777b57f

Adding preserve_file_name param to S3Hook.download_file method (#26886)
* Adding `preserve_file_name` param to `S3Hook.download_file` method
1 parent d544e8f commit 777b57f

File tree

2 files changed: 123 additions, 14 deletions

  • airflow/providers/amazon/aws/hooks/s3.py
  • tests/providers/amazon/aws/hooks/test_s3.py

airflow/providers/amazon/aws/hooks/s3.py

Lines changed: 43 additions & 5 deletions
@@ -29,9 +29,10 @@
 from inspect import signature
 from io import BytesIO
 from pathlib import Path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, gettempdir
 from typing import Any, Callable, List, TypeVar, cast
 from urllib.parse import urlparse
+from uuid import uuid4

 from boto3.s3.transfer import S3Transfer, TransferConfig
 from botocore.exceptions import ClientError
@@ -879,17 +880,38 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:

     @provide_bucket_name
     @unify_bucket_name_and_key
-    def download_file(self, key: str, bucket_name: str | None = None, local_path: str | None = None) -> str:
+    def download_file(
+        self,
+        key: str,
+        bucket_name: str | None = None,
+        local_path: str | None = None,
+        preserve_file_name: bool = False,
+        use_autogenerated_subdir: bool = True,
+    ) -> str:
         """
         Downloads a file from the S3 location to the local file system.

         :param key: The key path in S3.
         :param bucket_name: The specific bucket to use.
         :param local_path: The local path to the downloaded file. If no path is provided it will use the
             system's temporary directory.
+        :param preserve_file_name: If you want the downloaded file name to be the same name as it is in S3,
+            set this parameter to True. When set to False, a random filename will be generated.
+            Default: False.
+        :param use_autogenerated_subdir: Pairs with 'preserve_file_name = True' to download the file into a
+            random generated folder inside the 'local_path', useful to avoid collisions between various tasks
+            that might download the same file name. Set it to 'False' if you don't want it, and you want a
+            predictable path.
+            Default: True.
         :return: the file name.
         :rtype: str
         """
+        self.log.info(
+            "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
+            "want to use the original method from S3 API, please call "
+            "'S3Hook.get_conn().download_file()'"
+        )
+
         self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)

         try:
@@ -902,14 +924,30 @@ def download_file(self, key: str, bucket_name: str | None = None, local_path: st
             else:
                 raise e

-        with NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) as local_tmp_file:
+        if preserve_file_name:
+            local_dir = local_path if local_path else gettempdir()
+            subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
+            filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1]
+            file_path = Path(local_dir, subdir, filename_in_s3)
+
+            if file_path.is_file():
+                self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
+                raise FileExistsError
+
+            file_path.parent.mkdir(exist_ok=True, parents=True)
+
+            file = open(file_path, "wb")
+        else:
+            file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
+
+        with file:
             s3_obj.download_fileobj(
-                local_tmp_file,
+                file,
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )

-        return local_tmp_file.name
+        return file.name

     def generate_presigned_url(
         self,
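As a quick illustration of how the new preserve_file_name branch composes the download path (this mirrors the logic in the hunk above; the key and flag values below are hypothetical, not part of the commit):

from pathlib import Path
from tempfile import gettempdir
from uuid import uuid4

# Hypothetical inputs, for illustration only.
local_path = None                  # caller did not supply a local_path
use_autogenerated_subdir = True
s3_key = "test_key/test.log"

# Same composition steps as in the hook's preserve_file_name branch.
local_dir = local_path if local_path else gettempdir()
subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
filename_in_s3 = s3_key.rsplit("/", 1)[-1]

# e.g. /tmp/airflow_tmp_dir_1a2b3c4d/test.log
print(Path(local_dir, subdir, filename_in_s3))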

tests/providers/amazon/aws/hooks/test_s3.py

Lines changed: 80 additions & 9 deletions
@@ -20,6 +20,7 @@
 import gzip as gz
 import os
 import tempfile
+from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock

@@ -532,24 +533,94 @@ def test_function_with_test_key(self, test_key, bucket_name=None):

     @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
     def test_download_file(self, mock_temp_file):
-        mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file)
+        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
+            mock_temp_file.return_value = temp_file
+            s3_hook = S3Hook(aws_conn_id="s3_test")
+            s3_hook.check_for_key = Mock(return_value=True)
+            s3_obj = Mock()
+            s3_obj.download_fileobj = Mock(return_value=None)
+            s3_hook.get_key = Mock(return_value=s3_obj)
+            key = "test_key"
+            bucket = "test_bucket"
+
+            output_file = s3_hook.download_file(key=key, bucket_name=bucket)
+
+            s3_hook.get_key.assert_called_once_with(key, bucket)
+            s3_obj.download_fileobj.assert_called_once_with(
+                temp_file,
+                Config=s3_hook.transfer_config,
+                ExtraArgs=s3_hook.extra_args,
+            )
+
+            assert temp_file.name == output_file
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name(self, mock_open):
+        file_name = "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{file_name}"
+        local_folder = "/tmp"
+
         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
         s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
         s3_obj.download_fileobj = Mock(return_value=None)
         s3_hook.get_key = Mock(return_value=s3_obj)
-        key = "test_key"
-        bucket = "test_bucket"
+        s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_folder,
+            preserve_file_name=True,
+            use_autogenerated_subdir=False,
+        )

-        s3_hook.download_file(key=key, bucket_name=bucket)
+        mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")

-        s3_hook.get_key.assert_called_once_with(key, bucket)
-        s3_obj.download_fileobj.assert_called_once_with(
-            mock_temp_file,
-            Config=s3_hook.transfer_config,
-            ExtraArgs=s3_hook.extra_args,
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
+        file_name = "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{file_name}"
+        local_folder = "/tmp"
+
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        result_file = s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_folder,
+            preserve_file_name=True,
+            use_autogenerated_subdir=True,
         )

+        assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_")
+
+    def test_download_file_with_preserve_name_file_already_exists(self):
+        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
+            file_name = file.name.rsplit("/", 1)[-1]
+            bucket = "test_bucket"
+            key = f"test_key/{file_name}"
+            local_folder = "/tmp"
+            s3_hook = S3Hook(aws_conn_id="s3_test")
+            s3_hook.check_for_key = Mock(return_value=True)
+            s3_obj = Mock()
+            s3_obj.key = f"s3://{bucket}/{key}"
+            s3_obj.download_fileobj = Mock(return_value=None)
+            s3_hook.get_key = Mock(return_value=s3_obj)
+            with pytest.raises(FileExistsError):
+                s3_hook.download_file(
+                    key=key,
+                    bucket_name=bucket,
+                    local_path=local_folder,
+                    preserve_file_name=True,
+                    use_autogenerated_subdir=False,
+                )
+
     def test_generate_presigned_url(self, s3_bucket):
         hook = S3Hook()
         presigned_url = hook.generate_presigned_url(
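For completeness, a minimal caller-side sketch of the new parameters introduced by this commit (the connection id, bucket, and key below are placeholders, not values from the diff):

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id

# Default behaviour: a random file name, e.g. /tmp/airflow_tmp_XXXXXXXX
tmp_path = hook.download_file(key="data/report.csv", bucket_name="my-bucket")

# Keep the S3 file name inside a randomly named subdirectory,
# e.g. /tmp/airflow_tmp_dir_1a2b3c4d/report.csv
kept_path = hook.download_file(
    key="data/report.csv",
    bucket_name="my-bucket",
    local_path="/tmp",
    preserve_file_name=True,
)

# Fully predictable path (/tmp/report.csv); raises FileExistsError if that file already exists.
fixed_path = hook.download_file(
    key="data/report.csv",
    bucket_name="my-bucket",
    local_path="/tmp",
    preserve_file_name=True,
    use_autogenerated_subdir=False,
)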
