Skip to content

Commit 2d3d035

Browse files
committed
BUG: Fix b' prefix for bytes in to_csv() (pandas-dev#9712)
Add a new optional parameter named bytes_encoding to allow a specific encoding scheme to be used to decode the bytes.
1 parent 506eb54 commit 2d3d035

File tree

7 files changed

+194
-7
lines changed

7 files changed

+194
-7
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,7 @@ I/O
998998
- Bug in :meth:`~SQLDatabase.execute` was raising a ``ProgrammingError`` for some DB-API drivers when the SQL statement contained the `%` character and no parameters were present (:issue:`34211`)
999999
- Bug in :meth:`~pandas.io.stata.StataReader` which resulted in categorical variables with difference dtypes when reading data using an iterator. (:issue:`31544`)
10001000
- :meth:`HDFStore.keys` has now an optional `include` parameter that allows the retrieval of all native HDF5 table names (:issue:`29916`)
1001+
- Bug in :meth:`to_csv` which emitted b'' around bytes. It now has an optional `bytes_encoding` parameter that allows to pass a specific encoding scheme according to which the bytes are decoded. (:issue:`9712`)
10011002

10021003
Plotting
10031004
^^^^^^^^

pandas/_libs/lib.pyx

+17-1
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,17 @@ cdef class Validator:
15581558
else:
15591559
return False
15601560

1561+
cdef bint any(self, ndarray values) except -1:
1562+
if not self.n:
1563+
return False
1564+
cdef:
1565+
Py_ssize_t i
1566+
Py_ssize_t n = self.n
1567+
for i in range(n):
1568+
if self.is_valid(values[i]):
1569+
return True
1570+
return False
1571+
15611572
@cython.wraparound(False)
15621573
@cython.boundscheck(False)
15631574
cdef bint _validate(self, ndarray values) except -1:
@@ -1710,12 +1721,17 @@ cdef class BytesValidator(Validator):
17101721
return issubclass(self.dtype.type, np.bytes_)
17111722

17121723

1713-
cdef bint is_bytes_array(ndarray values, bint skipna=False):
1724+
cpdef bint is_bytes_array(ndarray values, bint skipna=False):
17141725
cdef:
17151726
BytesValidator validator = BytesValidator(len(values), values.dtype,
17161727
skipna=skipna)
17171728
return validator.validate(values)
17181729

1730+
cpdef bint is_any_bytes_in_array(ndarray values, bint skipna=False):
1731+
cdef:
1732+
BytesValidator validator = BytesValidator(len(values), values.dtype,
1733+
skipna=skipna)
1734+
return validator.any(values)
17191735

17201736
cdef class TemporalValidator(Validator):
17211737
cdef:

pandas/core/generic.py

+6
Original file line numberDiff line numberDiff line change
@@ -3031,6 +3031,7 @@ def to_csv(
30313031
index_label: Optional[Union[bool_t, str, Sequence[Label]]] = None,
30323032
mode: str = "w",
30333033
encoding: Optional[str] = None,
3034+
bytes_encoding: Optional[str] = None,
30343035
compression: Optional[Union[str, Mapping[str, str]]] = "infer",
30353036
quoting: Optional[int] = None,
30363037
quotechar: str = '"',
@@ -3088,6 +3089,10 @@ def to_csv(
30883089
encoding : str, optional
30893090
A string representing the encoding to use in the output file,
30903091
defaults to 'utf-8'.
3092+
bytes_encoding : str, optional
3093+
A string representing the encoding to use to decode the bytes
3094+
in the output file, defaults to using the 'encoding' parameter or the
3095+
encoding specified by the file object.
30913096
compression : str or dict, default 'infer'
30923097
If str, represents compression mode. If dict, value at 'method' is
30933098
the compression mode. Compression mode may be any of the following
@@ -3178,6 +3183,7 @@ def to_csv(
31783183
line_terminator=line_terminator,
31793184
sep=sep,
31803185
encoding=encoding,
3186+
bytes_encoding=bytes_encoding,
31813187
errors=errors,
31823188
compression=compression,
31833189
quoting=quoting,

pandas/core/indexes/base.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from pandas.core.ops import get_op_result_name
7878
from pandas.core.ops.invalid import make_invalid_op
7979
from pandas.core.sorting import ensure_key_mapped
80-
from pandas.core.strings import StringMethods
80+
from pandas.core.strings import StringMethods, str_decode
8181

8282
from pandas.io.formats.printing import (
8383
PrettyDict,
@@ -954,6 +954,8 @@ def to_native_types(self, slicer=None, **kwargs):
954954
Whether or not there are quoted values in `self`
955955
3) date_format : str
956956
The format used to represent date-like values.
957+
4) bytes_encoding : str
958+
The encoding scheme to use to decode the bytes.
957959
958960
Returns
959961
-------
@@ -965,7 +967,9 @@ def to_native_types(self, slicer=None, **kwargs):
965967
values = values[slicer]
966968
return values._format_native_types(**kwargs)
967969

968-
def _format_native_types(self, na_rep="", quoting=None, **kwargs):
970+
def _format_native_types(
971+
self, na_rep="", quoting=None, bytes_encoding=None, **kwargs
972+
):
969973
"""
970974
Actually format specific types of the index.
971975
"""
@@ -976,6 +980,12 @@ def _format_native_types(self, na_rep="", quoting=None, **kwargs):
976980
values = np.array(self, dtype=object, copy=True)
977981

978982
values[mask] = na_rep
983+
is_all_bytes = lib.is_bytes_array(values, skipna=True)
984+
is_any_bytes = lib.is_any_bytes_in_array(values, skipna=True)
985+
if is_any_bytes and not is_all_bytes:
986+
raise ValueError("Cannot mix types")
987+
if bytes_encoding is not None and is_all_bytes:
988+
values = str_decode(values, bytes_encoding)
979989
return values
980990

981991
def _summary(self, name=None) -> str_t:

pandas/core/internals/blocks.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
)
8383
import pandas.core.missing as missing
8484
from pandas.core.nanops import nanpercentile
85+
from pandas.core.strings import str_decode
8586

8687
if TYPE_CHECKING:
8788
from pandas import Index
@@ -642,13 +643,24 @@ def should_store(self, value: ArrayLike) -> bool:
642643
"""
643644
return is_dtype_equal(value.dtype, self.dtype)
644645

645-
def to_native_types(self, na_rep="nan", quoting=None, **kwargs):
646+
def to_native_types(
647+
self, na_rep="nan", bytes_encoding=None, quoting=None, **kwargs
648+
):
646649
""" convert to our native types format """
647650
values = self.values
648651

649652
mask = isna(values)
650653
itemsize = writers.word_len(na_rep)
651654

655+
length = values.shape[0]
656+
for i in range(length):
657+
is_all_bytes = lib.is_bytes_array(values[i], skipna=True)
658+
is_any_bytes = lib.is_any_bytes_in_array(values[i])
659+
if is_any_bytes and not is_all_bytes:
660+
raise ValueError("Cannot mix types")
661+
if bytes_encoding is not None and is_all_bytes:
662+
values[i] = str_decode(values[i], bytes_encoding)
663+
652664
if not self.is_object and not quoting and itemsize:
653665
values = values.astype(str)
654666
if values.dtype.itemsize / np.dtype("U1").itemsize < itemsize:

pandas/io/formats/csvs.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313

14-
from pandas._libs import writers as libwriters
14+
from pandas._libs import writers as libwriters, lib
1515
from pandas._typing import FilePathOrBuffer
1616

1717
from pandas.core.dtypes.generic import (
@@ -30,6 +30,23 @@
3030
)
3131

3232

33+
class EncodingConflictWarning(Warning):
34+
pass
35+
36+
37+
encoding_conflict_doc = """
38+
the encoding scheme: [%s] with which the the existing file object is opened \
39+
conflicted with the encoding scheme: [%s] mentioned in the .to_csv method. \
40+
Will be using encoding scheme mentioned by the file object that is [%s].
41+
"""
42+
43+
44+
def _mismatch_encoding(encoding, path_or_buf_encoding):
45+
if encoding is None or path_or_buf_encoding is None:
46+
return False
47+
return encoding != path_or_buf_encoding
48+
49+
3350
class CSVFormatter:
3451
def __init__(
3552
self,
@@ -44,6 +61,7 @@ def __init__(
4461
index_label: Optional[Union[bool, Hashable, Sequence[Hashable]]] = None,
4562
mode: str = "w",
4663
encoding: Optional[str] = None,
64+
bytes_encoding: Optional[str] = None,
4765
errors: str = "strict",
4866
compression: Union[str, Mapping[str, str], None] = "infer",
4967
quoting: Optional[int] = None,
@@ -75,12 +93,32 @@ def __init__(
7593
self.index = index
7694
self.index_label = index_label
7795
self.mode = mode
78-
if encoding is None:
79-
encoding = "utf-8"
96+
97+
if hasattr(self.path_or_buf, "encoding"):
98+
if _mismatch_encoding(encoding, self.path_or_buf.encoding):
99+
ws = encoding_conflict_doc % (
100+
self.path_or_buf.encoding,
101+
encoding,
102+
self.path_or_buf.encoding,
103+
)
104+
warnings.warn(ws, EncodingConflictWarning, stacklevel=2)
105+
if self.path_or_buf.encoding is None:
106+
encoding = "utf-8"
107+
else:
108+
encoding = self.path_or_buf.encoding
109+
else:
110+
if encoding is None:
111+
encoding = "utf-8"
112+
80113
self.encoding = encoding
81114
self.errors = errors
82115
self.compression = infer_compression(self.path_or_buf, compression)
83116

117+
if bytes_encoding is None:
118+
bytes_encoding = self.encoding
119+
120+
self.bytes_encoding = bytes_encoding
121+
84122
if quoting is None:
85123
quoting = csvlib.QUOTE_MINIMAL
86124
self.quoting = quoting
@@ -108,6 +146,7 @@ def __init__(
108146
if isinstance(cols, ABCIndexClass):
109147
cols = cols.to_native_types(
110148
na_rep=na_rep,
149+
bytes_encoding=bytes_encoding,
111150
float_format=float_format,
112151
date_format=date_format,
113152
quoting=self.quoting,
@@ -122,6 +161,7 @@ def __init__(
122161
if isinstance(cols, ABCIndexClass):
123162
cols = cols.to_native_types(
124163
na_rep=na_rep,
164+
bytes_encoding=bytes_encoding,
125165
float_format=float_format,
126166
date_format=date_format,
127167
quoting=self.quoting,
@@ -278,6 +318,8 @@ def _save_header(self):
278318
else:
279319
encoded_labels = []
280320

321+
self._bytes_to_str(encoded_labels)
322+
281323
if not has_mi_columns or has_aliases:
282324
encoded_labels += list(write_cols)
283325
writer.writerow(encoded_labels)
@@ -300,6 +342,7 @@ def _save_header(self):
300342
col_line.extend([""] * (len(index_label) - 1))
301343

302344
col_line.extend(columns._get_level_values(i))
345+
self._bytes_to_str(col_line)
303346

304347
writer.writerow(col_line)
305348

@@ -340,6 +383,7 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
340383
b = blocks[i]
341384
d = b.to_native_types(
342385
na_rep=self.na_rep,
386+
bytes_encoding=self.bytes_encoding,
343387
float_format=self.float_format,
344388
decimal=self.decimal,
345389
date_format=self.date_format,
@@ -353,10 +397,23 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
353397
ix = data_index.to_native_types(
354398
slicer=slicer,
355399
na_rep=self.na_rep,
400+
bytes_encoding=self.bytes_encoding,
356401
float_format=self.float_format,
357402
decimal=self.decimal,
358403
date_format=self.date_format,
359404
quoting=self.quoting,
360405
)
361406

362407
libwriters.write_csv_rows(self.data, ix, self.nlevels, self.cols, self.writer)
408+
409+
def _bytes_to_str(self, values):
410+
"""If all the values are bytes, then modify values list by decoding
411+
bytes to str."""
412+
np_values = np.array(values, dtype=object)
413+
is_all_bytes = lib.is_bytes_array(np_values)
414+
is_any_bytes = lib.is_any_bytes_in_array(np_values)
415+
if is_any_bytes and not is_all_bytes:
416+
raise ValueError("Cannot mix types")
417+
if self.bytes_encoding is not None and is_all_bytes:
418+
for i, value in enumerate(values):
419+
values[i] = value.decode(self.bytes_encoding)

pandas/tests/frame/test_to_csv.py

+85
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,91 @@ def test_to_csv_withcommas(self):
740740
df2 = self.read_csv(path)
741741
tm.assert_frame_equal(df2, df)
742742

743+
def test_to_csv_bytes(self):
744+
# GH 9712
745+
times = date_range("2013-10-27 23:00", "2013-10-28 00:00", freq="H")
746+
df = DataFrame(
747+
{b"hello": [b"abcd", b"world"], b"times": times}, index=[b"A", b"B"]
748+
)
749+
df.loc[b"C"] = np.nan
750+
df.index.name = b"idx"
751+
752+
df_expected = DataFrame(
753+
{"hello": ["abcd", "world"], "times": times}, index=["A", "B"]
754+
)
755+
df_expected.loc["C"] = np.nan
756+
df_expected.index.name = "idx"
757+
758+
with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
759+
df.to_csv(path, header=True)
760+
df_output = self.read_csv(path)
761+
df_output.times = to_datetime(df_output.times)
762+
tm.assert_frame_equal(df_output, df_expected)
763+
764+
non_unicode_byte = b"\xbc\xa6"
765+
non_unicode_decoded = non_unicode_byte.decode("gb18030")
766+
df = DataFrame({non_unicode_byte: [non_unicode_byte, b"world"]})
767+
df.index.name = "idx"
768+
769+
df_expected = DataFrame({non_unicode_decoded: [non_unicode_decoded, "world"]})
770+
df_expected.index.name = "idx"
771+
772+
with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
773+
df.to_csv(path, bytes_encoding="gb18030", header=True)
774+
df_output = self.read_csv(path)
775+
tm.assert_frame_equal(df_output, df_expected)
776+
777+
# decoding error, when transcoding fails
778+
with pytest.raises(UnicodeDecodeError):
779+
df.to_csv(bytes_encoding="utf-8")
780+
781+
# mixing of bytes and non-bytes
782+
df = DataFrame({"hello": [b"abcd", "world"]})
783+
with pytest.raises(ValueError):
784+
df.to_csv()
785+
df = DataFrame({b"hello": ["a", "b"], "world": ["c", "d"]})
786+
with pytest.raises(ValueError):
787+
df.to_csv()
788+
df = DataFrame({"hello": ["a", "b"], "world": ["c", "d"]}, index=["A", b"B"])
789+
with pytest.raises(ValueError):
790+
df.to_csv()
791+
792+
# multi-indexes
793+
iterables = [[b"A", b"B"], ["C", "D"]]
794+
index = pd.MultiIndex.from_product(iterables, names=[b"f", b"s"])
795+
data = np.array([[0, 0], [0, 0], [0, 0], [0, 0]])
796+
df = pd.DataFrame(data, index=index)
797+
798+
with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
799+
df.to_csv(path)
800+
import sys
801+
802+
df.to_csv(sys.stdout)
803+
with open(path) as csvfile:
804+
output = csvfile.readlines()
805+
806+
expected = [
807+
"f,s,0,1\n",
808+
"A,C,0,0\n",
809+
"A,D,0,0\n",
810+
"B,C,0,0\n",
811+
"B,D,0,0\n",
812+
]
813+
assert output == expected
814+
815+
# mixing of bytes and non-bytes in multi-indexes
816+
iterables = [[b"A", "B"], ["C", "D"]]
817+
index = pd.MultiIndex.from_product(iterables)
818+
df = pd.DataFrame(data, index=index)
819+
with pytest.raises(ValueError):
820+
df.to_csv()
821+
822+
iterables = [["A", "B"], ["C", "D"]]
823+
index = pd.MultiIndex.from_product(iterables, names=[b"f", "s"])
824+
df = pd.DataFrame(data, index=index)
825+
with pytest.raises(ValueError):
826+
df.to_csv()
827+
743828
def test_to_csv_mixed(self):
744829
def create_cols(name):
745830
return [f"{name}{i:03d}" for i in range(5)]

0 commit comments

Comments
 (0)