Skip to content

Commit cebf595

Browse files
committed
Athena tests refactoring.
1 parent 2f98ec0 commit cebf595

File tree

7 files changed

+1314
-1294
lines changed

7 files changed

+1314
-1294
lines changed

tests/_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from decimal import Decimal
55

66
import boto3
7+
import botocore.exceptions
78
import pandas as pd
89

910
import awswrangler as wr
@@ -470,3 +471,17 @@ def extract_cloudformation_outputs():
470471
for output in stack.get("Outputs"):
471472
outputs[output.get("OutputKey")] = output.get("OutputValue")
472473
return outputs
474+
475+
476+
def list_workgroups():
477+
client = boto3.client("athena")
478+
attempt = 1
479+
while True:
480+
try:
481+
return client.list_work_groups()
482+
except botocore.exceptions.ClientError as ex:
483+
if ex.response["Error"]["Code"] != "ThrottlingException":
484+
raise ex
485+
if attempt > 5:
486+
raise ex
487+
time.sleep(attempt + random.randrange(start=0, stop=3, step=1))

tests/conftest.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import awswrangler as wr
77

8-
from ._utils import extract_cloudformation_outputs, get_time_str_with_random_suffix, path_generator
8+
from ._utils import extract_cloudformation_outputs, get_time_str_with_random_suffix, list_workgroups, path_generator
99

1010

1111
@pytest.fixture(scope="session")
@@ -66,7 +66,7 @@ def loggroup(cloudformation_outputs):
6666
def workgroup0(bucket):
6767
wkg_name = "aws_data_wrangler_0"
6868
client = boto3.client("athena")
69-
wkgs = client.list_work_groups()
69+
wkgs = list_workgroups()
7070
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
7171
if wkg_name not in wkgs:
7272
client.create_work_group(
@@ -87,7 +87,7 @@ def workgroup0(bucket):
8787
def workgroup1(bucket):
8888
wkg_name = "aws_data_wrangler_1"
8989
client = boto3.client("athena")
90-
wkgs = client.list_work_groups()
90+
wkgs = list_workgroups()
9191
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
9292
if wkg_name not in wkgs:
9393
client.create_work_group(
@@ -111,7 +111,7 @@ def workgroup1(bucket):
111111
def workgroup2(bucket, kms_key):
112112
wkg_name = "aws_data_wrangler_2"
113113
client = boto3.client("athena")
114-
wkgs = client.list_work_groups()
114+
wkgs = list_workgroups()
115115
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
116116
if wkg_name not in wkgs:
117117
client.create_work_group(
@@ -135,7 +135,7 @@ def workgroup2(bucket, kms_key):
135135
def workgroup3(bucket, kms_key):
136136
wkg_name = "aws_data_wrangler_3"
137137
client = boto3.client("athena")
138-
wkgs = client.list_work_groups()
138+
wkgs = list_workgroups()
139139
wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
140140
if wkg_name not in wkgs:
141141
client.create_work_group(
@@ -199,6 +199,7 @@ def glue_table(glue_database):
199199
wr.catalog.delete_table_if_exists(database=glue_database, table=name)
200200
yield name
201201
wr.catalog.delete_table_if_exists(database=glue_database, table=name)
202+
print(f"Table {glue_database}.{name} deleted.")
202203

203204

204205
@pytest.fixture(scope="function")

0 commit comments

Comments
 (0)