|
20 | 20 | import gzip as gz
|
21 | 21 | import os
|
22 | 22 | import tempfile
|
| 23 | +from pathlib import Path |
23 | 24 | from unittest import mock
|
24 | 25 | from unittest.mock import Mock
|
25 | 26 |
|
@@ -532,24 +533,94 @@ def test_function_with_test_key(self, test_key, bucket_name=None):
|
532 | 533 |
|
533 | 534 | @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
|
534 | 535 | def test_download_file(self, mock_temp_file):
|
535 |
| - mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file) |
| 536 | + with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file: |
| 537 | + mock_temp_file.return_value = temp_file |
| 538 | + s3_hook = S3Hook(aws_conn_id="s3_test") |
| 539 | + s3_hook.check_for_key = Mock(return_value=True) |
| 540 | + s3_obj = Mock() |
| 541 | + s3_obj.download_fileobj = Mock(return_value=None) |
| 542 | + s3_hook.get_key = Mock(return_value=s3_obj) |
| 543 | + key = "test_key" |
| 544 | + bucket = "test_bucket" |
| 545 | + |
| 546 | + output_file = s3_hook.download_file(key=key, bucket_name=bucket) |
| 547 | + |
| 548 | + s3_hook.get_key.assert_called_once_with(key, bucket) |
| 549 | + s3_obj.download_fileobj.assert_called_once_with( |
| 550 | + temp_file, |
| 551 | + Config=s3_hook.transfer_config, |
| 552 | + ExtraArgs=s3_hook.extra_args, |
| 553 | + ) |
| 554 | + |
| 555 | + assert temp_file.name == output_file |
| 556 | + |
| 557 | + @mock.patch("airflow.providers.amazon.aws.hooks.s3.open") |
| 558 | + def test_download_file_with_preserve_name(self, mock_open): |
| 559 | + file_name = "test.log" |
| 560 | + bucket = "test_bucket" |
| 561 | + key = f"test_key/{file_name}" |
| 562 | + local_folder = "/tmp" |
| 563 | + |
536 | 564 | s3_hook = S3Hook(aws_conn_id="s3_test")
|
537 | 565 | s3_hook.check_for_key = Mock(return_value=True)
|
538 | 566 | s3_obj = Mock()
|
| 567 | + s3_obj.key = f"s3://{bucket}/{key}" |
539 | 568 | s3_obj.download_fileobj = Mock(return_value=None)
|
540 | 569 | s3_hook.get_key = Mock(return_value=s3_obj)
|
541 |
| - key = "test_key" |
542 |
| - bucket = "test_bucket" |
| 570 | + s3_hook.download_file( |
| 571 | + key=key, |
| 572 | + bucket_name=bucket, |
| 573 | + local_path=local_folder, |
| 574 | + preserve_file_name=True, |
| 575 | + use_autogenerated_subdir=False, |
| 576 | + ) |
543 | 577 |
|
544 |
| - s3_hook.download_file(key=key, bucket_name=bucket) |
| 578 | + mock_open.assert_called_once_with(Path(local_folder, file_name), "wb") |
545 | 579 |
|
546 |
| - s3_hook.get_key.assert_called_once_with(key, bucket) |
547 |
| - s3_obj.download_fileobj.assert_called_once_with( |
548 |
| - mock_temp_file, |
549 |
| - Config=s3_hook.transfer_config, |
550 |
| - ExtraArgs=s3_hook.extra_args, |
| 580 | + @mock.patch("airflow.providers.amazon.aws.hooks.s3.open") |
| 581 | + def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open): |
| 582 | + file_name = "test.log" |
| 583 | + bucket = "test_bucket" |
| 584 | + key = f"test_key/{file_name}" |
| 585 | + local_folder = "/tmp" |
| 586 | + |
| 587 | + s3_hook = S3Hook(aws_conn_id="s3_test") |
| 588 | + s3_hook.check_for_key = Mock(return_value=True) |
| 589 | + s3_obj = Mock() |
| 590 | + s3_obj.key = f"s3://{bucket}/{key}" |
| 591 | + s3_obj.download_fileobj = Mock(return_value=None) |
| 592 | + s3_hook.get_key = Mock(return_value=s3_obj) |
| 593 | + result_file = s3_hook.download_file( |
| 594 | + key=key, |
| 595 | + bucket_name=bucket, |
| 596 | + local_path=local_folder, |
| 597 | + preserve_file_name=True, |
| 598 | + use_autogenerated_subdir=True, |
551 | 599 | )
|
552 | 600 |
|
| 601 | + assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_") |
| 602 | + |
| 603 | + def test_download_file_with_preserve_name_file_already_exists(self): |
| 604 | + with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file: |
| 605 | + file_name = file.name.rsplit("/", 1)[-1] |
| 606 | + bucket = "test_bucket" |
| 607 | + key = f"test_key/{file_name}" |
| 608 | + local_folder = "/tmp" |
| 609 | + s3_hook = S3Hook(aws_conn_id="s3_test") |
| 610 | + s3_hook.check_for_key = Mock(return_value=True) |
| 611 | + s3_obj = Mock() |
| 612 | + s3_obj.key = f"s3://{bucket}/{key}" |
| 613 | + s3_obj.download_fileobj = Mock(return_value=None) |
| 614 | + s3_hook.get_key = Mock(return_value=s3_obj) |
| 615 | + with pytest.raises(FileExistsError): |
| 616 | + s3_hook.download_file( |
| 617 | + key=key, |
| 618 | + bucket_name=bucket, |
| 619 | + local_path=local_folder, |
| 620 | + preserve_file_name=True, |
| 621 | + use_autogenerated_subdir=False, |
| 622 | + ) |
| 623 | + |
553 | 624 | def test_generate_presigned_url(self, s3_bucket):
|
554 | 625 | hook = S3Hook()
|
555 | 626 | presigned_url = hook.generate_presigned_url(
|
|
0 commit comments