
Commit 6af5c7c

Merge pull request #28 from awslabs/compression
Add compression for Pandas.to_parquet
2 parents: ae3fc5a + 23e1619

File tree: 4 files changed, +113 -18 lines
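This commit threads a compression argument from the Pandas writers down into the Glue catalog table definitions. The writer itself (awswrangler/pandas.py) is among the four changed files but is not shown in this excerpt, so the usage sketch below is an assumption based on the PR title and the argument names visible in the diff, not code from the commit:

import pandas as pd

import awswrangler

session = awswrangler.Session()  # assumed entry point for this pre-1.0 API
df = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# "compression" is the argument this PR introduces; "snappy" is assumed to be
# an accepted codec, since the parquet writer delegates to pyarrow.
session.pandas.to_parquet(dataframe=df,
                          database="my_database",
                          path="s3://my-bucket/my-table/",
                          compression="snappy")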

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -72,3 +72,7 @@ class InvalidSerDe(Exception):
 
 class ApiError(Exception):
     pass
+
+
+class InvalidCompression(Exception):
+    pass
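The new InvalidCompression exception gives the writers a specific error to raise when an unsupported codec is requested. That validation lives outside this excerpt; a minimal sketch of the expected behavior, in which the function name and the supported-codec set are assumptions rather than code from the commit:

from awswrangler.exceptions import InvalidCompression

SUPPORTED_COMPRESSION = (None, "snappy", "gzip")  # assumed codec set

def validate_compression(compression):
    # Reject codecs the writer cannot hand to the underlying engine.
    if compression not in SUPPORTED_COMPRESSION:
        raise InvalidCompression(compression)
    return compression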

awswrangler/glue.py

Lines changed: 34 additions & 14 deletions
@@ -133,12 +133,14 @@ def metadata_to_glue(self,
                          partition_cols=None,
                          preserve_index=True,
                          mode="append",
+                         compression=None,
                          cast_columns=None,
                          extra_args=None):
         schema, partition_cols_schema = Glue._build_schema(
             dataframe=dataframe,
             partition_cols=partition_cols,
-            preserve_index=preserve_index)
+            preserve_index=preserve_index,
+            cast_columns=cast_columns)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -151,6 +153,7 @@ def metadata_to_glue(self,
             partition_cols_schema=partition_cols_schema,
             path=path,
             file_format=file_format,
+            compression=compression,
             extra_args=extra_args)
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
@@ -159,6 +162,7 @@
                 table=table,
                 partition_paths=partitions_tuples,
                 file_format=file_format,
+                compression=compression,
                 extra_args=extra_args)
 
     def delete_table_if_exists(self, database, table):
@@ -180,16 +184,18 @@ def create_table(self,
                      schema,
                      path,
                      file_format,
+                     compression,
                      partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
             table_input = Glue.parquet_table_definition(
-                table, partition_cols_schema, schema, path)
+                table, partition_cols_schema, schema, path, compression)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
                                                     partition_cols_schema,
                                                     schema,
                                                     path,
+                                                    compression,
                                                     extra_args=extra_args)
         else:
             raise UnsupportedFileFormat(file_format)
@@ -227,15 +233,21 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _extract_pyarrow_schema(dataframe, preserve_index):
+    def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
         cols = []
         cols_dtypes = {}
         schema = []
 
+        casted = []
+        if cast_columns is not None:
+            casted = cast_columns.keys()
+
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
-            if str(dtype) == "Int64":
+            if dtype == "Int64":
                 cols_dtypes[name] = "int64"
+            elif name in casted:
+                cols_dtypes[name] = cast_columns[name]
             else:
                 cols.append(name)
 
@@ -252,13 +264,18 @@ def _extract_pyarrow_schema(dataframe, preserve_index):
         return schema
 
     @staticmethod
-    def _build_schema(dataframe, partition_cols, preserve_index):
+    def _build_schema(dataframe,
+                      partition_cols,
+                      preserve_index,
+                      cast_columns={}):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
 
         pyarrow_schema = Glue._extract_pyarrow_schema(
-            dataframe=dataframe, preserve_index=preserve_index)
+            dataframe=dataframe,
+            preserve_index=preserve_index,
+            cast_columns=cast_columns)
 
         schema_built = []
         partition_cols_types = {}
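The two hunks above add an optional cast_columns mapping: columns listed in it skip pyarrow type inference and take the caller-supplied type directly. A standalone restatement of that resolution loop (the helper name here is illustrative, not part of the module):

import pandas as pd

def resolve_dtypes(dataframe, cast_columns=None):
    casted = cast_columns.keys() if cast_columns is not None else []
    cols, cols_dtypes = [], {}
    for name, dtype in dataframe.dtypes.to_dict().items():
        dtype = str(dtype)
        if dtype == "Int64":  # pandas nullable integer maps to plain int64
            cols_dtypes[name] = "int64"
        elif name in casted:  # explicit cast wins over inference
            cols_dtypes[name] = cast_columns[name]
        else:
            cols.append(name)  # left for pyarrow schema inference
    return cols, cols_dtypes

df = pd.DataFrame({"a": [1, 2], "b": [1.0, 2.5]})
print(resolve_dtypes(df, cast_columns={"b": "double"}))  # (['a'], {'b': 'double'})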
@@ -285,9 +302,10 @@ def _parse_table_name(path):
 
     @staticmethod
     def csv_table_definition(table, partition_cols_schema, schema, path,
-                             extra_args):
+                             compression, extra_args):
         if not partition_cols_schema:
             partition_cols_schema = []
+        compressed = False if compression is None else True
         sep = extra_args["sep"] if "sep" in extra_args else ","
         serde = extra_args.get("serde")
         if serde == "OpenCSVSerDe":
@@ -322,7 +340,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
             "EXTERNAL_TABLE",
             "Parameters": {
                 "classification": "csv",
-                "compressionType": "none",
+                "compressionType": str(compression).lower(),
                 "typeOfData": "file",
                 "delimiter": sep,
                 "columnsOrdered": "true",
@@ -337,7 +355,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
             "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
             "OutputFormat":
             "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
-            "Compressed": False,
+            "Compressed": True,
             "NumberOfBuckets": -1,
             "SerdeInfo": {
                 "Parameters": param,
@@ -347,7 +365,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
             "SortColumns": [],
             "Parameters": {
                 "classification": "csv",
-                "compressionType": "none",
+                "compressionType": str(compression).lower(),
                 "typeOfData": "file",
                 "delimiter": sep,
                 "columnsOrdered": "true",
@@ -386,9 +404,11 @@ def csv_partition_definition(partition, extra_args):
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema, schema, path):
+    def parquet_table_definition(table, partition_cols_schema, schema, path,
+                                 compression):
         if not partition_cols_schema:
             partition_cols_schema = []
+        compressed = False if compression is None else True
         return {
             "Name":
             table,
@@ -400,7 +420,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
             "EXTERNAL_TABLE",
             "Parameters": {
                 "classification": "parquet",
-                "compressionType": "none",
+                "compressionType": str(compression).lower(),
                 "typeOfData": "file",
             },
             "StorageDescriptor": {
@@ -413,7 +433,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
                 "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
-                "Compressed": False,
+                "Compressed": compressed,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
                     "SerializationLibrary":
@@ -427,7 +447,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
                 "Parameters": {
                     "CrawlerSchemaDeserializerVersion": "1.0",
                     "classification": "parquet",
-                    "compressionType": "none",
+                    "compressionType": str(compression).lower(),
                     "typeOfData": "file",
                 },
             },
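In the table definitions above, str(compression).lower() turns None into the string "none" (matching the previous hard-coded default) and a codec name into its lowercase form; the parquet definition additionally stores the compressed boolean derived from the same argument. A small illustration of that mapping, with the helper function name being hypothetical:

def compression_parameters(compression):
    compressed = False if compression is None else True
    return {
        "compressionType": str(compression).lower(),
        "Compressed": compressed,
    }

print(compression_parameters(None))      # {'compressionType': 'none', 'Compressed': False}
print(compression_parameters("snappy"))  # {'compressionType': 'snappy', 'Compressed': True}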
