Commit 0d2a0af

Merge pull request #27 from stijndehaes/feature/partition_columns
Partition columns now have correct type
2 parents: 98f0dad + 74a329f
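
For context: before this change, partition columns were registered in the Glue catalog with a hardcoded "string" type; after it, each partition column carries the Athena type inferred from the dataframe's pyarrow schema. A minimal sketch of the effect on a table's PartitionKeys, assuming two illustrative columns ("my_int", "my_date") and the usual pyarrow-to-Athena mapping (int64 -> bigint, date32 -> date):

    # Before this commit: every partition column typed as "string".
    partition_keys_before = [
        {"Name": "my_int", "Type": "string"},
        {"Name": "my_date", "Type": "string"},
    ]

    # After this commit: types inferred via Glue.type_pyarrow2athena
    # and threaded through to the table definition.
    partition_keys_after = [
        {"Name": "my_int", "Type": "bigint"},
        {"Name": "my_date", "Type": "date"},
    ]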

File tree

2 files changed: +69, -26 lines


awswrangler/glue.py

Lines changed: 31 additions & 23 deletions
@@ -135,9 +135,10 @@ def metadata_to_glue(self,
                          mode="append",
                          cast_columns=None,
                          extra_args=None):
-        schema = Glue._build_schema(dataframe=dataframe,
-                                    partition_cols=partition_cols,
-                                    preserve_index=preserve_index)
+        schema, partition_cols_schema = Glue._build_schema(
+            dataframe=dataframe,
+            partition_cols=partition_cols,
+            preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -147,7 +148,7 @@ def metadata_to_glue(self,
         self.create_table(database=database,
                           table=table,
                           schema=schema,
-                          partition_cols=partition_cols,
+                          partition_cols_schema=partition_cols_schema,
                           path=path,
                           file_format=file_format,
                           extra_args=extra_args)
@@ -180,14 +181,13 @@ def create_table(self,
                      schema,
                      path,
                      file_format,
-                     partition_cols=None,
+                     partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(
-                table, partition_cols, schema, path)
+            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
-                                                    partition_cols,
+                                                    partition_cols_schema,
                                                     schema,
                                                     path,
                                                     extra_args=extra_args)
@@ -248,13 +248,20 @@ def _build_schema(dataframe, partition_cols, preserve_index):
             dataframe=dataframe, preserve_index=preserve_index)

         schema_built = []
+        partition_cols_types = {}
         for name, dtype in pyarrow_schema:
-            if name not in partition_cols:
-                athena_type = Glue.type_pyarrow2athena(dtype)
+            athena_type = Glue.type_pyarrow2athena(dtype)
+            if name in partition_cols:
+                partition_cols_types[name] = athena_type
+            else:
                 schema_built.append((name, athena_type))

+        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+
         logger.debug(f"schema_built:\n{schema_built}")
-        return schema_built
+        logger.debug(
+            f"partition_cols_schema_built:\n{partition_cols_schema_built}")
+        return schema_built, partition_cols_schema_built

     @staticmethod
     def _parse_table_name(path):
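
The _build_schema hunk above is the core of the fix: partition columns are no longer skipped during type resolution; their Athena types are collected and returned alongside the regular schema. A self-contained sketch of that flow, with a plain dict standing in for Glue.type_pyarrow2athena (build_schema_sketch and the literal mapping are illustrative stand-ins, not the library's real API):

    # Sketch: resolve Athena types for all columns, splitting partition
    # columns out into their own (name, type) list.
    def build_schema_sketch(pyarrow_schema, partition_cols, to_athena):
        schema_built = []
        partition_cols_types = {}
        for name, dtype in pyarrow_schema:
            athena_type = to_athena(dtype)
            if name in partition_cols:
                partition_cols_types[name] = athena_type
            else:
                schema_built.append((name, athena_type))
        # Emit partition columns in the caller-specified order rather than
        # dataframe column order; partition key order matters in Glue.
        partition_cols_schema = [(name, partition_cols_types[name])
                                 for name in partition_cols]
        return schema_built, partition_cols_schema

    to_athena = {"double": "double", "date32": "date", "int64": "bigint"}.get
    schema, parts = build_schema_sketch(
        [("id", "double"), ("my_date", "date32"), ("my_int", "int64")],
        partition_cols=["my_int", "my_date"],
        to_athena=to_athena)
    # schema == [("id", "double")]
    # parts  == [("my_int", "bigint"), ("my_date", "date")]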
@@ -263,17 +270,17 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]

     @staticmethod
-    def csv_table_definition(table, partition_cols, schema, path, extra_args):
+    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
         sep = extra_args["sep"] if "sep" in extra_args else ","
-        if not partition_cols:
-            partition_cols = []
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name":
             table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {
@@ -334,16 +341,17 @@ def csv_partition_definition(partition):
         }

     @staticmethod
-    def parquet_table_definition(table, partition_cols, schema, path):
-        if not partition_cols:
-            partition_cols = []
+    def parquet_table_definition(table, partition_cols_schema,
+                                 schema, path):
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name":
             table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {

testing/test_awswrangler/test_pandas.py

Lines changed: 38 additions & 3 deletions
@@ -491,8 +491,41 @@ def test_to_csv_with_sep(
         dataframe2["id"] == 0].iloc[0]["name"]


-@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
-def test_to_parquet_types(session, bucket, database, index):
+@pytest.mark.parametrize("index, partition_cols", [
+    (None, []),
+    ("default", []),
+    ("my_date", []),
+    ("my_timestamp", []),
+    (None, ["my_int"]),
+    ("default", ["my_int"]),
+    ("my_date", ["my_int"]),
+    ("my_timestamp", ["my_int"]),
+    (None, ["my_float"]),
+    ("default", ["my_float"]),
+    ("my_date", ["my_float"]),
+    ("my_timestamp", ["my_float"]),
+    (None, ["my_bool"]),
+    ("default", ["my_bool"]),
+    ("my_date", ["my_bool"]),
+    ("my_timestamp", ["my_bool"]),
+    (None, ["my_date"]),
+    ("default", ["my_date"]),
+    ("my_date", ["my_date"]),
+    ("my_timestamp", ["my_date"]),
+    (None, ["my_timestamp"]),
+    ("default", ["my_timestamp"]),
+    ("my_date", ["my_timestamp"]),
+    ("my_timestamp", ["my_timestamp"]),
+    (None, ["my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp"]),
+    ("my_date", ["my_timestamp", "my_date"]),
+    ("my_timestamp", ["my_date", "my_timestamp"]),
+    (None, ["my_bool", "my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp", "my_int"]),
+    ("my_date", ["my_timestamp", "my_float", "my_date"]),
+    ("my_timestamp", ["my_int", "my_float", "my_bool", "my_date", "my_timestamp"]),
+])
+def test_to_parquet_types(session, bucket, database, index, partition_cols):
     dataframe = pandas.read_csv("data_samples/complex.csv",
                                 dtype={"my_int_with_null": "Int64"},
                                 parse_dates=["my_timestamp", "my_date"])
@@ -510,6 +543,7 @@ def test_to_parquet_types(session, bucket, database, index):
                       database=database,
                       path=f"s3://{bucket}/test/",
                       preserve_index=preserve_index,
+                      partition_cols=partition_cols,
                       mode="overwrite",
                       procs_cpu_bound=1)
     sleep(1)
@@ -518,7 +552,8 @@ def test_to_parquet_types(session, bucket, database, index):
     for row in dataframe2.itertuples():
         if index:
             if index == "default":
-                assert isinstance(row[8], numpy.int64)
+                ex_index_col = 8 - len(partition_cols)
+                assert isinstance(row[ex_index_col], numpy.int64)
             elif index == "my_date":
                 assert isinstance(row.new_index, date)
             elif index == "my_timestamp":
