
Commit ae3fc5a

Add SerDe options for Pandas.to_csv()
1 parent 0d2a0af commit ae3fc5a

5 files changed: +241 additions, -135 deletions

awswrangler/athena.py

Lines changed: 3 additions & 2 deletions
@@ -29,8 +29,8 @@ def _type_athena2pandas(dtype):
             return "float64"
         elif dtype == "boolean":
             return "bool"
-        elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
-            return "object"
+        elif dtype in ["string", "char", "varchar"]:
+            return "str"
         elif dtype == "timestamp":
             return "datetime64"
         elif dtype == "date":
@@ -53,6 +53,7 @@ def get_query_dtype(self, query_execution_id):
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
+        logger.debug(f"parse_timestamps: {parse_timestamps}")
         logger.debug(f"parse_dates: {parse_dates}")
         return dtype, parse_timestamps, parse_dates

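Note on awswrangler/athena.py: string-like Athena types ("string", "char", "varchar") now map to "str" instead of the generic "object", and the complex types ("array", "row", "map") are no longer mapped. A minimal sketch of how such a dtype map is typically handed to pandas; the column names and local path are hypothetical (internally read_sql_athena reads the CSV result object from S3):

import pandas as pd

# Hypothetical dtype map as _type_athena2pandas would now produce it:
# string-like Athena columns come back as "str", numeric ones keep their width.
dtype = {"customer_id": "Int64", "name": "str", "score": "float64"}
parse_timestamps = ["created_at"]

df = pd.read_csv("query_results.csv",      # placeholder for the S3 result object
                 dtype=dtype,
                 parse_dates=parse_timestamps)
print(df.dtypes)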
awswrangler/exceptions.py

Lines changed: 8 additions & 0 deletions
@@ -64,3 +64,11 @@ class QueryCancelled(Exception):
 
 class QueryFailed(Exception):
     pass
+
+
+class InvalidSerDe(Exception):
+    pass
+
+
+class ApiError(Exception):
+    pass

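Note on awswrangler/exceptions.py: InvalidSerDe is raised when an unknown SerDe name is passed (before anything is written), and ApiError surfaces errors returned by Glue's batch_create_partition. A small usage sketch; the Session entry point, bucket, and database names are assumptions for illustration:

import pandas as pd
import awswrangler
from awswrangler.exceptions import InvalidSerDe

df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
session = awswrangler.Session()

try:
    session.pandas.to_csv(dataframe=df,
                          path="s3://my-bucket/my-table/",  # hypothetical bucket
                          database="my_database",           # hypothetical database
                          table="my_table",
                          serde="SomeOtherSerDe")           # not in VALID_CSV_SERDES
except InvalidSerDe as error:
    print(f"Rejected before any write happened: {error}")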
awswrangler/glue.py

Lines changed: 82 additions & 38 deletions
@@ -5,7 +5,7 @@
 
 import pyarrow
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, InvalidSerDe, ApiError
 
 logger = logging.getLogger(__name__)
 
@@ -155,12 +155,11 @@ def metadata_to_glue(self,
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
-            self.add_partitions(
-                database=database,
-                table=table,
-                partition_paths=partitions_tuples,
-                file_format=file_format,
-            )
+            self.add_partitions(database=database,
+                                table=table,
+                                partition_paths=partitions_tuples,
+                                file_format=file_format,
+                                extra_args=extra_args)
 
     def delete_table_if_exists(self, database, table):
         try:
@@ -184,7 +183,8 @@ def create_table(self,
                      partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
+            table_input = Glue.parquet_table_definition(
+                table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
                                                     partition_cols_schema,
@@ -196,25 +196,31 @@
         self._client_glue.create_table(DatabaseName=database,
                                        TableInput=table_input)
 
-    def add_partitions(self, database, table, partition_paths, file_format):
+    def add_partitions(self, database, table, partition_paths, file_format,
+                       extra_args):
         if not partition_paths:
             return None
         partitions = list()
         for partition in partition_paths:
             if file_format == "parquet":
-                partition_def = Glue.parquet_partition_definition(partition)
+                partition_def = Glue.parquet_partition_definition(
+                    partition=partition)
             elif file_format == "csv":
-                partition_def = Glue.csv_partition_definition(partition)
+                partition_def = Glue.csv_partition_definition(
+                    partition=partition, extra_args=extra_args)
             else:
                 raise UnsupportedFileFormat(file_format)
             partitions.append(partition_def)
         pages_num = int(ceil(len(partitions) / 100.0))
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
-            self._client_glue.batch_create_partition(DatabaseName=database,
-                                                     TableName=table,
-                                                     PartitionInputList=page)
+            res = self._client_glue.batch_create_partition(
+                DatabaseName=database,
+                TableName=table,
+                PartitionInputList=page)
+            if len(res["Errors"]) > 0:
+                raise ApiError(f"{res['Errors'][0]}")
 
     def get_connection_details(self, name):
         return self._client_glue.get_connection(
@@ -223,18 +229,25 @@ def get_connection_details(self, name):
     @staticmethod
     def _extract_pyarrow_schema(dataframe, preserve_index):
         cols = []
+        cols_dtypes = {}
         schema = []
+
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
             if str(dtype) == "Int64":
-                schema.append((name, "int64"))
+                cols_dtypes[name] = "int64"
             else:
                 cols.append(name)
 
-        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
-        schema += [(str(x.name), str(x.type))
-                   for x in pyarrow.Schema.from_pandas(
-                       df=dataframe[cols], preserve_index=preserve_index)]
+        for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
+                                                preserve_index=preserve_index):
+            name = str(field.name)
+            dtype = str(field.type)
+            cols_dtypes[name] = dtype
+            if name not in dataframe.columns:
+                schema.append((name, dtype))
+
+        schema += [(name, cols_dtypes[name]) for name in dataframe.columns]
         logger.debug(f"schema: {schema}")
         return schema
 

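Note on _extract_pyarrow_schema: the rewrite tracks every column's pyarrow type in cols_dtypes and emits index fields (when preserved) first, followed by the dataframe columns in their original order, with pandas nullable Int64 columns reported as int64. An illustrative sketch calling the private helper directly (for illustration only):

import pandas as pd
from awswrangler.glue import Glue

df = pd.DataFrame({
    "name": ["foo", "bar"],
    "value": [1.5, 2.0],
    "id": pd.array([1, 2], dtype="Int64"),   # pandas nullable integer
})

schema = Glue._extract_pyarrow_schema(dataframe=df, preserve_index=False)
print(schema)
# Expected along the lines of:
# [('name', 'string'), ('value', 'double'), ('id', 'int64')]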
@@ -256,7 +269,8 @@ def _build_schema(dataframe, partition_cols, preserve_index):
             else:
                 schema_built.append((name, athena_type))
 
-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built = [(name, partition_cols_types[name])
+                                       for name in partition_cols]
 
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(
@@ -270,17 +284,40 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
-        sep = extra_args["sep"] if "sep" in extra_args else ","
+    def csv_table_definition(table, partition_cols_schema, schema, path,
+                             extra_args):
         if not partition_cols_schema:
             partition_cols_schema = []
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+            refined_par_schema = [(name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, "string") for name, dtype in schema]
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+            dtypes_allowed = ["int", "bigint", "float", "double"]
+            refined_par_schema = [(name, dtype) if dtype in dtypes_allowed else
+                                  (name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, dtype) if dtype in dtypes_allowed else
+                              (name, "string") for name, dtype in schema]
+        else:
+            raise InvalidSerDe(f"{serde} in not in the valid SerDe list.")
         return {
             "Name":
             table,
             "PartitionKeys": [{
                 "Name": x[0],
                 "Type": x[1]
-            } for x in partition_cols_schema],
+            } for x in refined_par_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {
@@ -295,54 +332,61 @@ def csv_table_definition(table, partition_cols_schema, schema, path, extra_args)
                 "Columns": [{
                     "Name": x[0],
                     "Type": x[1]
-                } for x in schema],
+                } for x in refined_schema],
                 "Location": path,
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
                 "Compressed": False,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": sep
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
                 "SortColumns": [],
                 "Parameters": {
                     "classification": "csv",
                     "compressionType": "none",
                     "typeOfData": "file",
-                    "delimiter": ",",
+                    "delimiter": sep,
                     "columnsOrdered": "true",
                     "areColumnsQuoted": "false",
                 },
             },
         }
 
     @staticmethod
-    def csv_partition_definition(partition):
+    def csv_partition_definition(partition, extra_args):
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+        else:
+            raise InvalidSerDe(f"{serde} in not in the valid SerDe list.")
         return {
             "StorageDescriptor": {
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "Location": partition[0],
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": ","
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
             },
             "Values": partition[1],
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema,
-                                 schema, path):
+    def parquet_table_definition(table, partition_cols_schema, schema, path):
         if not partition_cols_schema:
             partition_cols_schema = []
         return {

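Note on the Glue definitions: the SerDe choice now travels through extra_args into both the table and partition definitions. A quick sketch of what the static csv_partition_definition helper emits for each option; the partition location and values are made up:

from awswrangler.glue import Glue

partition = ("s3://my-bucket/my-table/year=2019/", ["2019"])  # hypothetical partition

open_csv = Glue.csv_partition_definition(
    partition=partition,
    extra_args={"sep": ",", "serde": "OpenCSVSerDe"})
lazy = Glue.csv_partition_definition(
    partition=partition,
    extra_args={"sep": "|", "serde": "LazySimpleSerDe"})

# OpenCSVSerDe quotes every field; LazySimpleSerDe keeps a plain field delimiter.
print(open_csv["StorageDescriptor"]["SerdeInfo"])
# {'Parameters': {'separatorChar': ',', 'quoteChar': '"', 'escapeChar': '\\'},
#  'SerializationLibrary': 'org.apache.hadoop.hive.serde2.OpenCSVSerde'}
print(lazy["StorageDescriptor"]["SerdeInfo"])
# {'Parameters': {'field.delim': '|', 'escape.delim': '\\'},
#  'SerializationLibrary': 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'}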
awswrangler/pandas.py

Lines changed: 25 additions & 6 deletions
@@ -11,7 +11,7 @@
 from pyarrow import parquet
 
 from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat,\
-    AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe
+    AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, InvalidSerDe
 from awswrangler.utils import calculate_bounders
 from awswrangler import s3
 
@@ -26,6 +26,9 @@ def _get_bounders(dataframe, num_partitions):
 
 
 class Pandas:
+
+    VALID_CSV_SERDES = ["OpenCSVSerDe", "LazySimpleSerDe"]
+
     def __init__(self, session):
         self._session = session
 
@@ -427,15 +430,17 @@ def read_sql_athena(self,
                                 parse_dates=parse_timestamps,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
-            for col in parse_dates:
-                ret[col] = ret[col].dt.date
+            if len(ret.index) > 0:
+                for col in parse_dates:
+                    ret[col] = ret[col].dt.date
             return ret
 
     def to_csv(
             self,
             dataframe,
             path,
             sep=",",
+            serde="OpenCSVSerDe",
             database=None,
             table=None,
             partition_cols=None,
@@ -451,6 +456,7 @@ def to_csv(
         :param dataframe: Pandas Dataframe
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/
         :param sep: Same as pandas.to_csv()
+        :param serde: SerDe library name (e.g. OpenCSVSerDe, LazySimpleSerDe)
         :param database: AWS Glue Database name
         :param table: AWS Glue table name
         :param partition_cols: List of columns names that will be partitions on S3
@@ -460,7 +466,11 @@ def to_csv(
         :param procs_io_bound: Number of cores used for I/O bound tasks
         :return: List of objects written on S3
         """
-        extra_args = {"sep": sep}
+        if serde not in Pandas.VALID_CSV_SERDES:
+            raise InvalidSerDe(
+                f"{serde} in not in the valid SerDe list ({Pandas.VALID_CSV_SERDES})"
+            )
+        extra_args = {"sep": sep, "serde": serde}
         return self.to_s3(dataframe=dataframe,
                           path=path,
                           file_format="csv",
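The final hunk below adjusts the CSV writer itself: OpenCSVSerDe gets csv.QUOTE_ALL, LazySimpleSerDe gets csv.QUOTE_NONE, and both use "\\" as the escape character. A standalone sketch of the payloads those settings produce with plain pandas (no S3 involved, values made up):

import csv
import pandas as pd

df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})

# OpenCSVSerDe path: quote every field.
open_csv = df.to_csv(None, header=False, index=False,
                     quoting=csv.QUOTE_ALL, escapechar="\\")
# '"1","foo"\n"2","bar"\n'

# LazySimpleSerDe path: no quoting at all.
lazy = df.to_csv(None, header=False, index=False,
                 quoting=csv.QUOTE_NONE, escapechar="\\")
# '1,foo\n2,bar\n'

print(open_csv, lazy, sep="")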
@@ -745,8 +755,17 @@ def write_csv_dataframe(dataframe,
                             fs,
                             extra_args=None):
         csv_extra_args = {}
-        if "sep" in extra_args:
-            csv_extra_args["sep"] = extra_args["sep"]
+        sep = extra_args.get("sep")
+        if sep is not None:
+            csv_extra_args["sep"] = sep
+        serde = extra_args.get("serde")
+        if serde is not None:
+            if serde == "OpenCSVSerDe":
+                csv_extra_args["quoting"] = csv.QUOTE_ALL
+                csv_extra_args["escapechar"] = "\\"
+            elif serde == "LazySimpleSerDe":
+                csv_extra_args["quoting"] = csv.QUOTE_NONE
+                csv_extra_args["escapechar"] = "\\"
         csv_buffer = bytes(
             dataframe.to_csv(None,
                              header=False,

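Taken together, Pandas.to_csv() now accepts a serde argument (defaulting to "OpenCSVSerDe") and threads it through extra_args to the Glue definitions and the CSV writer. A usage sketch assuming the library's Session entry point; bucket, database, and table names are placeholders:

import pandas as pd
import awswrangler

df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
session = awswrangler.Session()

# Default OpenCSVSerDe: every field quoted, all catalog columns typed as string.
session.pandas.to_csv(dataframe=df,
                      path="s3://my-bucket/opencsv-table/",  # placeholder
                      database="my_database",                # placeholder
                      table="opencsv_table")

# LazySimpleSerDe: no quoting, "\\" escape, and int/bigint/float/double
# columns keep their types in the Glue catalog.
session.pandas.to_csv(dataframe=df,
                      path="s3://my-bucket/lazy-table/",     # placeholder
                      sep="|",
                      serde="LazySimpleSerDe",
                      database="my_database",
                      table="lazy_table")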