11
11
12
12
import numpy as np
13
13
14
- from pandas ._libs import writers as libwriters
14
+ from pandas ._libs import writers as libwriters , lib
15
15
from pandas ._typing import FilePathOrBuffer
16
16
17
17
from pandas .core .dtypes .generic import (
30
30
)
31
31
32
32
33
class EncodingConflictWarning(Warning):
    """
    Warning category for a disagreement between two encodings.

    Issued by ``CSVFormatter.__init__`` when the ``encoding`` argument differs
    from the ``encoding`` attribute of the file object being written to.
    """
35
+
36
+
37
# Template for the EncodingConflictWarning message.  The three ``%s``
# placeholders are filled, in order, with: the file object's encoding, the
# encoding passed by the caller, and the encoding that is actually used
# (again the file object's).
encoding_conflict_doc = """
the encoding scheme: [%s] with which the existing file object is opened \
conflicted with the encoding scheme: [%s] mentioned in the .to_csv method. \
Will be using encoding scheme mentioned by the file object that is [%s].
"""
42
+
43
+
44
+ def _mismatch_encoding (encoding , path_or_buf_encoding ):
45
+ if encoding is None or path_or_buf_encoding is None :
46
+ return False
47
+ return encoding != path_or_buf_encoding
48
+
49
+
33
50
class CSVFormatter :
34
51
def __init__ (
35
52
self ,
@@ -44,6 +61,7 @@ def __init__(
44
61
index_label : Optional [Union [bool , Hashable , Sequence [Hashable ]]] = None ,
45
62
mode : str = "w" ,
46
63
encoding : Optional [str ] = None ,
64
+ bytes_encoding : Optional [str ] = None ,
47
65
errors : str = "strict" ,
48
66
compression : Union [str , Mapping [str , str ], None ] = "infer" ,
49
67
quoting : Optional [int ] = None ,
@@ -75,12 +93,32 @@ def __init__(
75
93
self .index = index
76
94
self .index_label = index_label
77
95
self .mode = mode
78
- if encoding is None :
79
- encoding = "utf-8"
96
+
97
+ if hasattr (self .path_or_buf , "encoding" ):
98
+ if _mismatch_encoding (encoding , self .path_or_buf .encoding ):
99
+ ws = encoding_conflict_doc % (
100
+ self .path_or_buf .encoding ,
101
+ encoding ,
102
+ self .path_or_buf .encoding ,
103
+ )
104
+ warnings .warn (ws , EncodingConflictWarning , stacklevel = 2 )
105
+ if self .path_or_buf .encoding is None :
106
+ encoding = "utf-8"
107
+ else :
108
+ encoding = self .path_or_buf .encoding
109
+ else :
110
+ if encoding is None :
111
+ encoding = "utf-8"
112
+
80
113
self .encoding = encoding
81
114
self .errors = errors
82
115
self .compression = infer_compression (self .path_or_buf , compression )
83
116
117
+ if bytes_encoding is None :
118
+ bytes_encoding = self .encoding
119
+
120
+ self .bytes_encoding = bytes_encoding
121
+
84
122
if quoting is None :
85
123
quoting = csvlib .QUOTE_MINIMAL
86
124
self .quoting = quoting
@@ -108,6 +146,7 @@ def __init__(
108
146
if isinstance (cols , ABCIndexClass ):
109
147
cols = cols .to_native_types (
110
148
na_rep = na_rep ,
149
+ bytes_encoding = bytes_encoding ,
111
150
float_format = float_format ,
112
151
date_format = date_format ,
113
152
quoting = self .quoting ,
@@ -122,6 +161,7 @@ def __init__(
122
161
if isinstance (cols , ABCIndexClass ):
123
162
cols = cols .to_native_types (
124
163
na_rep = na_rep ,
164
+ bytes_encoding = bytes_encoding ,
125
165
float_format = float_format ,
126
166
date_format = date_format ,
127
167
quoting = self .quoting ,
@@ -278,6 +318,8 @@ def _save_header(self):
278
318
else :
279
319
encoded_labels = []
280
320
321
+ self ._bytes_to_str (encoded_labels )
322
+
281
323
if not has_mi_columns or has_aliases :
282
324
encoded_labels += list (write_cols )
283
325
writer .writerow (encoded_labels )
@@ -300,6 +342,7 @@ def _save_header(self):
300
342
col_line .extend (["" ] * (len (index_label ) - 1 ))
301
343
302
344
col_line .extend (columns ._get_level_values (i ))
345
+ self ._bytes_to_str (col_line )
303
346
304
347
writer .writerow (col_line )
305
348
@@ -340,6 +383,7 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
340
383
b = blocks [i ]
341
384
d = b .to_native_types (
342
385
na_rep = self .na_rep ,
386
+ bytes_encoding = self .bytes_encoding ,
343
387
float_format = self .float_format ,
344
388
decimal = self .decimal ,
345
389
date_format = self .date_format ,
@@ -353,10 +397,23 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
353
397
ix = data_index .to_native_types (
354
398
slicer = slicer ,
355
399
na_rep = self .na_rep ,
400
+ bytes_encoding = self .bytes_encoding ,
356
401
float_format = self .float_format ,
357
402
decimal = self .decimal ,
358
403
date_format = self .date_format ,
359
404
quoting = self .quoting ,
360
405
)
361
406
362
407
libwriters .write_csv_rows (self .data , ix , self .nlevels , self .cols , self .writer )
408
+
409
def _bytes_to_str(self, values):
    """
    Decode the entries of ``values`` in place when all of them are bytes.

    Parameters
    ----------
    values : list
        Labels about to be written; modified in place, nothing is returned.

    Raises
    ------
    ValueError
        If ``lib.is_any_bytes_in_array`` reports bytes among the values while
        ``lib.is_bytes_array`` does not report the whole array as bytes.
    """
    as_objects = np.array(values, dtype=object)
    every_is_bytes = lib.is_bytes_array(as_objects)
    some_is_bytes = lib.is_any_bytes_in_array(as_objects)

    if some_is_bytes and not every_is_bytes:
        raise ValueError("Cannot mix types")

    # Nothing to decode unless every entry is bytes and a codec is configured.
    if self.bytes_encoding is None or not every_is_bytes:
        return

    for position, raw in enumerate(values):
        values[position] = raw.decode(self.bytes_encoding)
0 commit comments