@@ -133,12 +133,14 @@ def metadata_to_glue(self,
                          partition_cols=None,
                          preserve_index=True,
                          mode="append",
+                         compression=None,
                          cast_columns=None,
                          extra_args=None):
         schema, partition_cols_schema = Glue._build_schema(
             dataframe=dataframe,
             partition_cols=partition_cols,
-            preserve_index=preserve_index)
+            preserve_index=preserve_index,
+            cast_columns=cast_columns)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -151,6 +153,7 @@ def metadata_to_glue(self,
                           partition_cols_schema=partition_cols_schema,
                           path=path,
                           file_format=file_format,
+                          compression=compression,
                           extra_args=extra_args)
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
@@ -159,6 +162,7 @@ def metadata_to_glue(self,
                                 table=table,
                                 partition_paths=partitions_tuples,
                                 file_format=file_format,
+                                compression=compression,
                                 extra_args=extra_args)

     def delete_table_if_exists(self, database, table):
@@ -180,16 +184,18 @@ def create_table(self,
                      schema,
                      path,
                      file_format,
+                     compression,
                      partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
             table_input = Glue.parquet_table_definition(
-                table, partition_cols_schema, schema, path)
+                table, partition_cols_schema, schema, path, compression)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
                                                     partition_cols_schema,
                                                     schema,
                                                     path,
+                                                    compression,
                                                     extra_args=extra_args)
         else:
             raise UnsupportedFileFormat(file_format)
@@ -227,15 +233,21 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]

     @staticmethod
-    def _extract_pyarrow_schema(dataframe, preserve_index):
+    def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None):
         cols = []
         cols_dtypes = {}
         schema = []

+        casted = []
+        if cast_columns is not None:
+            casted = cast_columns.keys()
+
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
-            if str(dtype) == "Int64":
+            if dtype == "Int64":
                 cols_dtypes[name] = "int64"
+            elif name in casted:
+                cols_dtypes[name] = cast_columns[name]
             else:
                 cols.append(name)

@@ -252,13 +264,18 @@ def _extract_pyarrow_schema(dataframe, preserve_index):
         return schema

     @staticmethod
-    def _build_schema(dataframe, partition_cols, preserve_index):
+    def _build_schema(dataframe,
+                      partition_cols,
+                      preserve_index,
+                      cast_columns={}):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []

         pyarrow_schema = Glue._extract_pyarrow_schema(
-            dataframe=dataframe, preserve_index=preserve_index)
+            dataframe=dataframe,
+            preserve_index=preserve_index,
+            cast_columns=cast_columns)

         schema_built = []
         partition_cols_types = {}
@@ -285,9 +302,10 @@ def _parse_table_name(path):

     @staticmethod
     def csv_table_definition(table, partition_cols_schema, schema, path,
-                             extra_args):
+                             compression, extra_args):
         if not partition_cols_schema:
             partition_cols_schema = []
+        compressed = False if compression is None else True
         sep = extra_args["sep"] if "sep" in extra_args else ","
         serde = extra_args.get("serde")
         if serde == "OpenCSVSerDe":
@@ -322,7 +340,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
             "EXTERNAL_TABLE",
             "Parameters": {
                 "classification": "csv",
-                "compressionType": "none",
+                "compressionType": str(compression).lower(),
                 "typeOfData": "file",
                 "delimiter": sep,
                 "columnsOrdered": "true",
@@ -337,7 +355,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
-                "Compressed": False,
+                "Compressed": compressed,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
                     "Parameters": param,
@@ -347,7 +365,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path,
                 "SortColumns": [],
                 "Parameters": {
                     "classification": "csv",
-                    "compressionType": "none",
+                    "compressionType": str(compression).lower(),
                     "typeOfData": "file",
                     "delimiter": sep,
                     "columnsOrdered": "true",
@@ -386,9 +404,11 @@ def csv_partition_definition(partition, extra_args):
         }

     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema, schema, path):
+    def parquet_table_definition(table, partition_cols_schema, schema, path,
+                                 compression):
         if not partition_cols_schema:
             partition_cols_schema = []
+        compressed = False if compression is None else True
         return {
             "Name":
             table,
@@ -400,7 +420,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
             "EXTERNAL_TABLE",
             "Parameters": {
                 "classification": "parquet",
-                "compressionType": "none",
+                "compressionType": str(compression).lower(),
                 "typeOfData": "file",
             },
             "StorageDescriptor": {
@@ -413,7 +433,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
                 "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
-                "Compressed": False,
+                "Compressed": compressed,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
                     "SerializationLibrary":
@@ -427,7 +447,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path):
                 "Parameters": {
                     "CrawlerSchemaDeserializerVersion": "1.0",
                     "classification": "parquet",
-                    "compressionType": "none",
+                    "compressionType": str(compression).lower(),
                     "typeOfData": "file",
                 },
             },
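The compression handling added above is identical in the CSV and Parquet table definitions, so it can be illustrated in isolation. Below is a minimal standalone sketch of just the fields this diff touches; the helper name `compression_parameters` is hypothetical and everything else in the real table definitions is elided.

```python
# Sketch of the compression metadata emitted by the patched table definitions.
# Only the keys touched by this diff are shown.
def compression_parameters(compression):
    compressed = False if compression is None else True
    return {
        "Parameters": {"compressionType": str(compression).lower()},
        "StorageDescriptor": {"Compressed": compressed},
    }

print(compression_parameters("snappy"))
# {'Parameters': {'compressionType': 'snappy'}, 'StorageDescriptor': {'Compressed': True}}
print(compression_parameters(None))
# {'Parameters': {'compressionType': 'none'}, 'StorageDescriptor': {'Compressed': False}}
```

Note that `str(None).lower()` evaluates to `"none"`, so passing `compression=None` reproduces the previously hard-coded `"compressionType": "none"`.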
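The `cast_columns` argument threaded through `metadata_to_glue`, `_build_schema`, and `_extract_pyarrow_schema` lets the caller override the inferred type for selected columns. The following is a standalone sketch of just the dtype-selection loop added above; the function name `select_dtypes` is hypothetical and pandas is assumed.

```python
# Sketch of the dtype-selection loop in _extract_pyarrow_schema after this diff:
# columns listed in cast_columns skip pyarrow inference and take the caller-supplied type.
import pandas as pd


def select_dtypes(dataframe, cast_columns=None):
    casted = []
    if cast_columns is not None:
        casted = cast_columns.keys()
    cols = []         # columns left for pyarrow schema inference
    cols_dtypes = {}  # columns with an explicitly assigned type
    for name, dtype in dataframe.dtypes.to_dict().items():
        dtype = str(dtype)
        if dtype == "Int64":
            cols_dtypes[name] = "int64"
        elif name in casted:
            cols_dtypes[name] = cast_columns[name]
        else:
            cols.append(name)
    return cols, cols_dtypes


df = pd.DataFrame({"a": [1, 2], "b": [1.0, 2.0]})
print(select_dtypes(df, cast_columns={"b": "double"}))
# (['a'], {'b': 'double'})
```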