Commit 6bf8ab6

Adding support for Date type for pandas integrations

1 parent 4c95e87

File tree

4 files changed: +118 additions, -28 deletions

    awswrangler/athena.py
    awswrangler/glue.py
    awswrangler/pandas.py
    testing/test_awswrangler/test_pandas.py
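Taken together, the changes below let a pandas DataFrame that carries datetime.date values round-trip through Parquet, the Glue catalog, and Athena, and come back as date objects instead of full timestamps. A minimal end-to-end sketch of the intent, assuming a Session created the same way as in the project's test suite and using placeholder bucket/database/table names:

from datetime import date

import pandas
import awswrangler

session = awswrangler.Session()  # credentials resolved by boto3, as in the tests

df = pandas.DataFrame({
    "id": [1, 2],
    "my_date": [date(2019, 2, 2), date(2019, 2, 3)],  # plain dates, not timestamps
})

# Write to S3 as Parquet and register the table in the Glue catalog; with this
# commit the "my_date" column should be declared as the Athena "date" type.
session.pandas.to_parquet(dataframe=df,
                          database="my_database",           # placeholder
                          path="s3://my-bucket/my_table/",  # placeholder
                          preserve_index=False,
                          mode="overwrite")

# Read it back through Athena; "my_date" should come back as datetime.date values.
df2 = session.pandas.read_sql_athena(sql="select * from my_table",
                                     database="my_database")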

awswrangler/athena.py

Lines changed: 9 additions & 4 deletions
@@ -31,25 +31,30 @@ def _type_athena2pandas(dtype):
             return "bool"
         elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
             return "object"
-        elif dtype in ["timestamp", "date"]:
+        elif dtype == "timestamp":
             return "datetime64"
+        elif dtype == "date":
+            return "date"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")

     def get_query_dtype(self, query_execution_id):
         cols_metadata = self.get_query_columns_metadata(
             query_execution_id=query_execution_id)
         dtype = {}
+        parse_timestamps = []
         parse_dates = []
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
-            if ptype == "datetime64":
-                parse_dates.append(col_name)
+            if ptype in ["datetime64", "date"]:
+                parse_timestamps.append(col_name)
+                if ptype == "date":
+                    parse_dates.append(col_name)
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_dates
+        return dtype, parse_timestamps, parse_dates

     def create_athena_bucket(self):
         """

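On the Athena side, "timestamp" and "date" no longer collapse to the same pandas mapping, and get_query_dtype() now returns three values instead of two. A quick illustration of the new behavior, assuming the Athena class in awswrangler/athena.py is importable as shown (the column names are hypothetical):

from awswrangler.athena import Athena

# The two Athena types now map to different pandas-side handling.
assert Athena._type_athena2pandas(dtype="timestamp") == "datetime64"
assert Athena._type_athena2pandas(dtype="date") == "date"

# For a query returning (my_timestamp timestamp, my_date date), get_query_dtype()
# would now yield roughly:
#   dtype            -> {}                           (no plain-dtype columns)
#   parse_timestamps -> ["my_timestamp", "my_date"]  (parsed by read_csv as datetime64)
#   parse_dates      -> ["my_date"]                  (downcast to datetime.date afterwards)
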
awswrangler/glue.py

Lines changed: 52 additions & 22 deletions
@@ -3,6 +3,8 @@
 import logging
 from datetime import datetime, date

+import pyarrow
+
 from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat

 logger = logging.getLogger(__name__)
@@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}

+    @staticmethod
+    def type_pyarrow2athena(dtype):
+        dtype = str(dtype).lower()
+        if dtype == "int32":
+            return "int"
+        elif dtype == "int64":
+            return "bigint"
+        elif dtype == "float":
+            return "float"
+        elif dtype == "double":
+            return "double"
+        elif dtype == "bool":
+            return "boolean"
+        elif dtype == "string":
+            return "string"
+        elif dtype.startswith("timestamp"):
+            return "timestamp"
+        elif dtype.startswith("date"):
+            return "date"
+        else:
+            raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
     @staticmethod
     def type_pandas2athena(dtype):
         dtype = dtype.lower()
@@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
             return "boolean"
         elif dtype == "object":
             return "string"
-        elif dtype[:10] == "datetime64":
+        elif dtype.startswith("datetime64"):
             return "timestamp"
         else:
             raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
@@ -113,8 +137,7 @@ def metadata_to_glue(self,
                          extra_args=None):
         schema = Glue._build_schema(dataframe=dataframe,
                                     partition_cols=partition_cols,
-                                    preserve_index=preserve_index,
-                                    cast_columns=cast_columns)
+                                    preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -198,31 +221,38 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]

     @staticmethod
-    def _build_schema(dataframe,
-                      partition_cols,
-                      preserve_index,
-                      cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
+        cols = []
+        schema = []
+        for name, dtype in dataframe.dtypes.to_dict().items():
+            dtype = str(dtype)
+            if str(dtype) == "Int64":
+                schema.append((name, "int64"))
+            else:
+                cols.append(name)
+
+        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
+        schema += [(str(x.name), str(x.type))
+                   for x in pyarrow.Schema.from_pandas(
+                       df=dataframe[cols], preserve_index=preserve_index)]
+        logger.debug(f"schema: {schema}")
+        return schema
+
+    @staticmethod
+    def _build_schema(dataframe, partition_cols, preserve_index):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
+
+        pyarrow_schema = Glue._extract_pyarrow_schema(
+            dataframe=dataframe, preserve_index=preserve_index)
+
         schema_built = []
-        if preserve_index:
-            name = str(
-                dataframe.index.name) if dataframe.index.name else "index"
-            dataframe.index.name = "index"
-            dtype = str(dataframe.index.dtype)
-            if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
-                schema_built.append((name, athena_type))
-        for col in dataframe.columns:
-            name = str(col)
-            if cast_columns and name in cast_columns:
-                dtype = cast_columns[name]
-            else:
-                dtype = str(dataframe[name].dtype)
+        for name, dtype in pyarrow_schema:
             if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
+                athena_type = Glue.type_pyarrow2athena(dtype)
                 schema_built.append((name, athena_type))
+
         logger.debug(f"schema_built:\n{schema_built}")
         return schema_built

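With this change the Glue schema is derived from pyarrow rather than from raw pandas dtypes, with nullable Int64 columns appended to the schema by hand before pyarrow sees the frame. A small standalone sketch of the inference that type_pyarrow2athena then maps to Athena types; the column names are made up and the exact type strings can vary by pyarrow version:

from datetime import date

import pandas
import pyarrow

df = pandas.DataFrame({
    "my_int": [1, 2],
    "my_float": [1.5, 2.5],
    "my_date": [date(2019, 2, 2), date(2019, 2, 3)],
    "my_timestamp": pandas.to_datetime(["2018-01-01", "2018-01-02"]),
})

# pyarrow reports datetime.date columns as a date type (e.g. "date32[day]") and
# datetime64 columns as "timestamp[ns]", which is why type_pyarrow2athena matches
# on startswith("date") / startswith("timestamp").
schema = pyarrow.Schema.from_pandas(df=df, preserve_index=False)
for field in schema:
    print(field.name, str(field.type))
# Roughly: my_int int64, my_float double, my_date date32[day], my_timestamp timestamp[ns]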

awswrangler/pandas.py

Lines changed: 4 additions & 2 deletions
@@ -419,14 +419,16 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
-                                parse_dates=parse_dates,
+                                parse_dates=parse_timestamps,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
+            for col in parse_dates:
+                ret[col] = ret[col].dt.date
         return ret

     def to_csv(
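pandas.read_csv has no dedicated date dtype, so the Athena result is read with every temporal column in parse_dates (here the parse_timestamps list) and the date subset is truncated to datetime.date afterwards. A standalone sketch of that two-step conversion with made-up data, independent of the library:

import io

import pandas

csv_data = io.StringIO("my_timestamp,my_date\n2018-01-01 04:03:02.001,2019-02-02\n")

parse_timestamps = ["my_timestamp", "my_date"]  # everything parsed as datetime64 first
parse_dates = ["my_date"]                       # Athena "date" columns to downcast

df = pandas.read_csv(csv_data, parse_dates=parse_timestamps)
for col in parse_dates:
    df[col] = df[col].dt.date  # datetime64 -> datetime.date

print(type(df.loc[0, "my_timestamp"]))  # pandas.Timestamp
print(type(df.loc[0, "my_date"]))       # datetime.date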

testing/test_awswrangler/test_pandas.py

Lines changed: 53 additions & 0 deletions
@@ -489,3 +489,56 @@ def test_to_csv_with_sep(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert dataframe[dataframe["id"] == 0].iloc[0]["name"] == dataframe2[
         dataframe2["id"] == 0].iloc[0]["name"]
+
+
+@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
+def test_to_parquet_types(session, bucket, database, index):
+    dataframe = pandas.read_csv("data_samples/complex.csv",
+                                dtype={"my_int_with_null": "Int64"},
+                                parse_dates=["my_timestamp", "my_date"])
+    dataframe["my_date"] = dataframe["my_date"].dt.date
+    dataframe["my_bool"] = True
+
+    preserve_index = True
+    if not index:
+        preserve_index = False
+    elif index != "default":
+        dataframe["new_index"] = dataframe[index]
+        dataframe = dataframe.set_index("new_index")
+
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/test/",
+                              preserve_index=preserve_index,
+                              mode="overwrite",
+                              procs_cpu_bound=1)
+    sleep(1)
+    dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
+                                                database=database)
+    for row in dataframe2.itertuples():
+        if index:
+            if index == "default":
+                assert isinstance(row[8], numpy.int64)
+            elif index == "my_date":
+                assert isinstance(row.new_index, date)
+            elif index == "my_timestamp":
+                assert isinstance(row.new_index, datetime)
+        assert isinstance(row.my_timestamp, datetime)
+        assert type(row.my_date) == date
+        assert isinstance(row.my_float, float)
+        assert isinstance(row.my_int, numpy.int64)
+        assert isinstance(row.my_string, str)
+        assert isinstance(row.my_bool, bool)
+        assert str(row.my_timestamp) == "2018-01-01 04:03:02.001000"
+        assert str(row.my_date) == "2019-02-02"
+        assert str(row.my_float) == "12345.6789"
+        assert str(row.my_int) == "123456789"
+        assert str(row.my_bool) == "True"
+        assert str(
+            row.my_string
+        ) == "foo\nboo\nbar\nFOO\nBOO\nBAR\nxxxxx\nÁÃÀÂÇ\n汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå"
+    assert len(dataframe.index) == len(dataframe2.index)
+    if index:
+        assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
+    else:
+        assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
