 import pyarrow
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, InvalidSerDe, ApiError
 
 logger = logging.getLogger(__name__)
 
@@ -155,12 +155,11 @@ def metadata_to_glue(self,
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
-            self.add_partitions(
-                database=database,
-                table=table,
-                partition_paths=partitions_tuples,
-                file_format=file_format,
-            )
+            self.add_partitions(database=database,
+                                table=table,
+                                partition_paths=partitions_tuples,
+                                file_format=file_format,
+                                extra_args=extra_args)
 
     def delete_table_if_exists(self, database, table):
         try:
@@ -184,7 +183,8 @@ def create_table(self,
                      partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
+            table_input = Glue.parquet_table_definition(
+                table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
                                                     partition_cols_schema,
@@ -196,25 +196,31 @@ def create_table(self,
         self._client_glue.create_table(DatabaseName=database,
                                        TableInput=table_input)
 
-    def add_partitions(self, database, table, partition_paths, file_format):
+    def add_partitions(self, database, table, partition_paths, file_format,
+                       extra_args):
         if not partition_paths:
             return None
         partitions = list()
         for partition in partition_paths:
             if file_format == "parquet":
-                partition_def = Glue.parquet_partition_definition(partition)
+                partition_def = Glue.parquet_partition_definition(
+                    partition=partition)
             elif file_format == "csv":
-                partition_def = Glue.csv_partition_definition(partition)
+                partition_def = Glue.csv_partition_definition(
+                    partition=partition, extra_args=extra_args)
             else:
                 raise UnsupportedFileFormat(file_format)
             partitions.append(partition_def)
         pages_num = int(ceil(len(partitions) / 100.0))
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
-            self._client_glue.batch_create_partition(DatabaseName=database,
-                                                     TableName=table,
-                                                     PartitionInputList=page)
+            res = self._client_glue.batch_create_partition(
+                DatabaseName=database,
+                TableName=table,
+                PartitionInputList=page)
+            if len(res["Errors"]) > 0:
+                raise ApiError(f"{res['Errors'][0]}")
 
     def get_connection_details(self, name):
         return self._client_glue.get_connection(
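The reworked `add_partitions` keeps the existing paging over Glue's `batch_create_partition` (the API accepts at most 100 `PartitionInput` entries per call) and now raises `ApiError` as soon as a response reports a failed partition. A small, hypothetical illustration of the paging arithmetic only; the list below is a made-up stand-in, not part of the change:

```python
# Hypothetical sketch (not part of the diff): how the partition list is
# consumed in pages of 100 before each batch_create_partition call.
from math import ceil

partitions = [f"partition-{i}" for i in range(250)]  # stand-in for PartitionInput dicts
pages_num = int(ceil(len(partitions) / 100.0))
assert pages_num == 3
for _ in range(pages_num):
    page = partitions[:100]
    del partitions[:100]
    print(len(page))  # 100, 100, 50
```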
@@ -223,18 +229,25 @@ def get_connection_details(self, name):
     @staticmethod
     def _extract_pyarrow_schema(dataframe, preserve_index):
         cols = []
+        cols_dtypes = {}
         schema = []
+
         for name, dtype in dataframe.dtypes.to_dict().items():
             dtype = str(dtype)
             if str(dtype) == "Int64":
-                schema.append((name, "int64"))
+                cols_dtypes[name] = "int64"
             else:
                 cols.append(name)
 
-        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
-        schema += [(str(x.name), str(x.type))
-                   for x in pyarrow.Schema.from_pandas(
-                       df=dataframe[cols], preserve_index=preserve_index)]
+        for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
+                                                preserve_index=preserve_index):
+            name = str(field.name)
+            dtype = str(field.type)
+            cols_dtypes[name] = dtype
+            if name not in dataframe.columns:
+                schema.append((name, dtype))
+
+        schema += [(name, cols_dtypes[name]) for name in dataframe.columns]
         logger.debug(f"schema: {schema}")
         return schema
 
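`_extract_pyarrow_schema` now collects dtypes in a dict so that pandas nullable `Int64` columns, which are kept away from `pyarrow.Schema.from_pandas`, are mapped to `int64` by hand, and the final schema is emitted in the DataFrame's own column order (any preserved index fields come first). A minimal sketch of the intended behaviour, assuming the class is importable as `awswrangler.glue.Glue` and that pyarrow infers the object column as `string`:

```python
# Hypothetical sketch, not part of the diff: expected behaviour of the
# refactored _extract_pyarrow_schema for a mixed-dtype frame.
import pandas as pd

from awswrangler.glue import Glue  # assumed import path

df = pd.DataFrame({
    "id": pd.Series([1, 2, None], dtype="Int64"),  # nullable int, mapped to "int64" manually
    "name": ["a", "b", "c"],                       # left to pyarrow inference
})

# Expected to print something like [('id', 'int64'), ('name', 'string')],
# i.e. the DataFrame's own column order.
print(Glue._extract_pyarrow_schema(dataframe=df, preserve_index=False))
```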
@@ -256,7 +269,8 @@ def _build_schema(dataframe, partition_cols, preserve_index):
             else:
                 schema_built.append((name, athena_type))
 
-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built = [(name, partition_cols_types[name])
+                                       for name in partition_cols]
 
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(
@@ -270,17 +284,40 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
-        sep = extra_args["sep"] if "sep" in extra_args else ","
+    def csv_table_definition(table, partition_cols_schema, schema, path,
+                             extra_args):
         if not partition_cols_schema:
             partition_cols_schema = []
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+            refined_par_schema = [(name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, "string") for name, dtype in schema]
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+            dtypes_allowed = ["int", "bigint", "float", "double"]
+            refined_par_schema = [(name, dtype) if dtype in dtypes_allowed else
+                                  (name, "string")
+                                  for name, dtype in partition_cols_schema]
+            refined_schema = [(name, dtype) if dtype in dtypes_allowed else
+                              (name, "string") for name, dtype in schema]
+        else:
+            raise InvalidSerDe(f"{serde} is not in the valid SerDe list.")
         return {
             "Name":
             table,
             "PartitionKeys": [{
                 "Name": x[0],
                 "Type": x[1]
-            } for x in partition_cols_schema],
+            } for x in refined_par_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {
@@ -295,54 +332,61 @@ def csv_table_definition(table, partition_cols_schema, schema, path, extra_args)
                 "Columns": [{
                     "Name": x[0],
                     "Type": x[1]
-                } for x in schema],
+                } for x in refined_schema],
                 "Location": path,
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "OutputFormat":
                 "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
                 "Compressed": False,
                 "NumberOfBuckets": -1,
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": sep
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
                 "SortColumns": [],
                 "Parameters": {
                     "classification": "csv",
                     "compressionType": "none",
                     "typeOfData": "file",
-                    "delimiter": ",",
+                    "delimiter": sep,
                     "columnsOrdered": "true",
                     "areColumnsQuoted": "false",
                 },
             },
         }
 
     @staticmethod
-    def csv_partition_definition(partition):
+    def csv_partition_definition(partition, extra_args):
+        sep = extra_args["sep"] if "sep" in extra_args else ","
+        serde = extra_args.get("serde")
+        if serde == "OpenCSVSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
+            param = {
+                "separatorChar": sep,
+                "quoteChar": "\"",
+                "escapeChar": "\\",
+            }
+        elif serde == "LazySimpleSerDe":
+            serde_fullname = "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+            param = {"field.delim": sep, "escape.delim": "\\"}
+        else:
+            raise InvalidSerDe(f"{serde} is not in the valid SerDe list.")
         return {
             "StorageDescriptor": {
                 "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                 "Location": partition[0],
                 "SerdeInfo": {
-                    "Parameters": {
-                        "field.delim": ","
-                    },
-                    "SerializationLibrary":
-                    "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    "Parameters": param,
+                    "SerializationLibrary": serde_fullname,
                 },
                 "StoredAsSubDirectories": False,
             },
            "Values": partition[1],
        }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema,
-                                 schema, path):
+    def parquet_table_definition(table, partition_cols_schema, schema, path):
         if not partition_cols_schema:
             partition_cols_schema = []
         return {
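Taken together, the CSV paths now read everything from `extra_args`: `sep` sets the delimiter and `serde` picks between `OpenCSVSerDe` (quoted/escaped, every column typed as `string`) and `LazySimpleSerDe` (keeps `int`, `bigint`, `float` and `double`); any other value raises `InvalidSerDe`. A minimal usage sketch, assuming the class is importable as `awswrangler.glue.Glue`; the bucket and partition values are made up for illustration:

```python
# Hypothetical usage sketch, not part of the diff: what the new
# csv_partition_definition builds for the two accepted SerDes.
from awswrangler.glue import Glue  # assumed import path

partition = ("s3://my-bucket/my-table/year=2019/", ["2019"])  # (location, values)

opencsv = Glue.csv_partition_definition(
    partition=partition,
    extra_args={"sep": ";", "serde": "OpenCSVSerDe"})
print(opencsv["StorageDescriptor"]["SerdeInfo"]["SerializationLibrary"])
# -> org.apache.hadoop.hive.serde2.OpenCSVSerde

lazy = Glue.csv_partition_definition(
    partition=partition,
    extra_args={"sep": ";", "serde": "LazySimpleSerDe"})
print(lazy["StorageDescriptor"]["SerdeInfo"]["Parameters"])
# -> {'field.delim': ';', 'escape.delim': '\\'}

# Any other serde value raises InvalidSerDe, e.g.:
# Glue.csv_partition_definition(partition=partition, extra_args={"serde": "JsonSerDe"})
```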