Skip to content

Commit 3869dc7

Browse files
committed
BUG: Avoids b' prefix for bytes in to_csv() (pandas-dev#9712)
Add a new optional parameter named bytes_encoding to allow a specific encoding scheme to be used to decode the bytes.
1 parent 998e2ab commit 3869dc7

File tree

8 files changed

+195
-8
lines changed

8 files changed

+195
-8
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ I/O
10301030
- Bug in :meth:`read_excel` for ODS files removes 0.0 values (:issue:`27222`)
10311031
- Bug in :meth:`ujson.encode` was raising an `OverflowError` with numbers larger than sys.maxsize (:issue:`34395`)
10321032
- Bug in :meth:`HDFStore.append_to_multiple` was raising a ``ValueError`` when the min_itemsize parameter is set (:issue:`11238`)
1033+
- Bug in :meth:`to_csv` which emitted a ``b''`` prefix around bytes values. It now accepts an optional ``bytes_encoding`` parameter that specifies the encoding used to decode bytes before they are written (:issue:`9712`)
10331034

10341035
Plotting
10351036
^^^^^^^^

pandas/_libs/lib.pyx

+17-1
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,17 @@ cdef class Validator:
15581558
else:
15591559
return False
15601560

1561+
cdef bint any(self, ndarray values) except -1:
    # Report whether at least one of the first ``self.n`` elements of
    # ``values`` passes ``self.is_valid``; stops at the first match.
    cdef:
        Py_ssize_t pos
        Py_ssize_t count = self.n

    # Nothing to inspect when the validator was built with length zero.
    if not count:
        return False

    for pos in range(count):
        if self.is_valid(values[pos]):
            return True

    return False
1571+
15611572
@cython.wraparound(False)
15621573
@cython.boundscheck(False)
15631574
cdef bint _validate(self, ndarray values) except -1:
@@ -1710,12 +1721,17 @@ cdef class BytesValidator(Validator):
17101721
return issubclass(self.dtype.type, np.bytes_)
17111722

17121723

1713-
cdef bint is_bytes_array(ndarray values, bint skipna=False):
1724+
cpdef bint is_bytes_array(ndarray values, bint skipna=False):
    # Build a BytesValidator from the length and dtype of ``values`` and
    # return the result of its ``validate`` method. ``skipna`` is forwarded
    # to the validator unchanged.
    cdef BytesValidator checker

    checker = BytesValidator(len(values), values.dtype, skipna=skipna)
    return checker.validate(values)
17181729

1730+
cpdef bint is_any_bytes_in_array(ndarray values, bint skipna=False):
    # Build a BytesValidator from the length and dtype of ``values`` and
    # return the result of its ``any`` method, i.e. whether at least one
    # element is accepted by the validator.
    cdef BytesValidator checker

    checker = BytesValidator(len(values), values.dtype, skipna=skipna)
    return checker.any(values)
17191735

17201736
cdef class TemporalValidator(Validator):
17211737
cdef:

pandas/core/generic.py

+6
Original file line numberDiff line numberDiff line change
@@ -3000,6 +3000,7 @@ def to_csv(
30003000
index_label: Optional[Union[bool_t, str, Sequence[Label]]] = None,
30013001
mode: str = "w",
30023002
encoding: Optional[str] = None,
3003+
bytes_encoding: Optional[str] = None,
30033004
compression: Optional[Union[str, Mapping[str, str]]] = "infer",
30043005
quoting: Optional[int] = None,
30053006
quotechar: str = '"',
@@ -3057,6 +3058,10 @@ def to_csv(
30573058
encoding : str, optional
30583059
A string representing the encoding to use in the output file,
30593060
defaults to 'utf-8'.
3061+
bytes_encoding : str, optional
3062+
A string representing the encoding used to decode bytes values
3063+
before they are written, defaults to the 'encoding' parameter or the
3064+
encoding specified by the file object.
30603065
compression : str or dict, default 'infer'
30613066
If str, represents compression mode. If dict, value at 'method' is
30623067
the compression mode. Compression mode may be any of the following
@@ -3147,6 +3152,7 @@ def to_csv(
31473152
line_terminator=line_terminator,
31483153
sep=sep,
31493154
encoding=encoding,
3155+
bytes_encoding=bytes_encoding,
31503156
errors=errors,
31513157
compression=compression,
31523158
quoting=quoting,

pandas/core/indexes/base.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
from pandas.core.ops import get_op_result_name
7878
from pandas.core.ops.invalid import make_invalid_op
7979
from pandas.core.sorting import ensure_key_mapped
80-
from pandas.core.strings import StringMethods
80+
from pandas.core.strings import StringMethods, str_decode
8181

8282
from pandas.io.formats.printing import (
8383
PrettyDict,
@@ -954,6 +954,8 @@ def to_native_types(self, slicer=None, **kwargs):
954954
Whether or not there are quoted values in `self`
955955
3) date_format : str
956956
The format used to represent date-like values.
957+
4) bytes_encoding : str
958+
The encoding scheme to use to decode the bytes.
957959
958960
Returns
959961
-------
@@ -965,7 +967,9 @@ def to_native_types(self, slicer=None, **kwargs):
965967
values = values[slicer]
966968
return values._format_native_types(**kwargs)
967969

968-
def _format_native_types(self, na_rep="", quoting=None, **kwargs):
970+
def _format_native_types(
971+
self, na_rep="", quoting=None, bytes_encoding=None, **kwargs
972+
):
969973
"""
970974
Actually format specific types of the index.
971975
"""
@@ -976,6 +980,12 @@ def _format_native_types(self, na_rep="", quoting=None, **kwargs):
976980
values = np.array(self, dtype=object, copy=True)
977981

978982
values[mask] = na_rep
983+
is_all_bytes = lib.is_bytes_array(values, skipna=True)
984+
is_any_bytes = lib.is_any_bytes_in_array(values, skipna=True)
985+
if is_any_bytes and not is_all_bytes:
986+
raise ValueError("Cannot mix types")
987+
if bytes_encoding is not None and is_all_bytes:
988+
values = str_decode(values, bytes_encoding)
979989
return values
980990

981991
def _summary(self, name=None) -> str_t:

pandas/core/internals/blocks.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
)
8484
import pandas.core.missing as missing
8585
from pandas.core.nanops import nanpercentile
86+
from pandas.core.strings import str_decode
8687

8788
if TYPE_CHECKING:
8889
from pandas import Index
@@ -653,13 +654,24 @@ def should_store(self, value: ArrayLike) -> bool:
653654
"""
654655
return is_dtype_equal(value.dtype, self.dtype)
655656

656-
def to_native_types(self, na_rep="nan", quoting=None, **kwargs):
657+
def to_native_types(
658+
self, na_rep="nan", bytes_encoding=None, quoting=None, **kwargs
659+
):
657660
""" convert to our native types format """
658661
values = self.values
659662

660663
mask = isna(values)
661664
itemsize = writers.word_len(na_rep)
662665

666+
length = values.shape[0]
667+
for i in range(length):
668+
is_all_bytes = lib.is_bytes_array(values[i], skipna=True)
669+
is_any_bytes = lib.is_any_bytes_in_array(values[i])
670+
if is_any_bytes and not is_all_bytes:
671+
raise ValueError("Cannot mix types")
672+
if bytes_encoding is not None and is_all_bytes:
673+
values[i] = str_decode(values[i], bytes_encoding)
674+
663675
if not self.is_object and not quoting and itemsize:
664676
values = values.astype(str)
665677
if values.dtype.itemsize / np.dtype("U1").itemsize < itemsize:

pandas/io/formats/csvs.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313

14-
from pandas._libs import writers as libwriters
14+
from pandas._libs import writers as libwriters, lib
1515
from pandas._typing import FilePathOrBuffer
1616

1717
from pandas.core.dtypes.generic import (
@@ -30,6 +30,23 @@
3030
)
3131

3232

33+
class EncodingConflictWarning(Warning):
    """
    Warning category used when the encoding requested for CSV output
    differs from the encoding of the file object being written to.
    """
35+
36+
37+
# Template for the EncodingConflictWarning message. It takes three ``%s``
# values, in order: the file object's encoding, the encoding passed to
# ``to_csv``, and the encoding that will actually be used (the file object's).
encoding_conflict_doc = """
the encoding scheme: [%s] with which the existing file object is opened \
conflicted with the encoding scheme: [%s] mentioned in the .to_csv method. \
Will be using encoding scheme mentioned by the file object that is [%s].
"""
42+
43+
44+
def _mismatch_encoding(encoding, path_or_buf_encoding):
45+
if encoding is None or path_or_buf_encoding is None:
46+
return False
47+
return encoding != path_or_buf_encoding
48+
49+
3350
class CSVFormatter:
3451
def __init__(
3552
self,
@@ -44,6 +61,7 @@ def __init__(
4461
index_label: Optional[Union[bool, Hashable, Sequence[Hashable]]] = None,
4562
mode: str = "w",
4663
encoding: Optional[str] = None,
64+
bytes_encoding: Optional[str] = None,
4765
errors: str = "strict",
4866
compression: Union[str, Mapping[str, str], None] = "infer",
4967
quoting: Optional[int] = None,
@@ -75,12 +93,32 @@ def __init__(
7593
self.index = index
7694
self.index_label = index_label
7795
self.mode = mode
78-
if encoding is None:
79-
encoding = "utf-8"
96+
97+
if hasattr(self.path_or_buf, "encoding"):
98+
if _mismatch_encoding(encoding, self.path_or_buf.encoding):
99+
ws = encoding_conflict_doc % (
100+
self.path_or_buf.encoding,
101+
encoding,
102+
self.path_or_buf.encoding,
103+
)
104+
warnings.warn(ws, EncodingConflictWarning, stacklevel=2)
105+
if self.path_or_buf.encoding is None:
106+
encoding = "utf-8"
107+
else:
108+
encoding = self.path_or_buf.encoding
109+
else:
110+
if encoding is None:
111+
encoding = "utf-8"
112+
80113
self.encoding = encoding
81114
self.errors = errors
82115
self.compression = infer_compression(self.path_or_buf, compression)
83116

117+
if bytes_encoding is None:
118+
bytes_encoding = self.encoding
119+
120+
self.bytes_encoding = bytes_encoding
121+
84122
if quoting is None:
85123
quoting = csvlib.QUOTE_MINIMAL
86124
self.quoting = quoting
@@ -108,6 +146,7 @@ def __init__(
108146
if isinstance(cols, ABCIndexClass):
109147
cols = cols.to_native_types(
110148
na_rep=na_rep,
149+
bytes_encoding=bytes_encoding,
111150
float_format=float_format,
112151
date_format=date_format,
113152
quoting=self.quoting,
@@ -122,6 +161,7 @@ def __init__(
122161
if isinstance(cols, ABCIndexClass):
123162
cols = cols.to_native_types(
124163
na_rep=na_rep,
164+
bytes_encoding=bytes_encoding,
125165
float_format=float_format,
126166
date_format=date_format,
127167
quoting=self.quoting,
@@ -278,6 +318,8 @@ def _save_header(self):
278318
else:
279319
encoded_labels = []
280320

321+
self._bytes_to_str(encoded_labels)
322+
281323
if not has_mi_columns or has_aliases:
282324
encoded_labels += list(write_cols)
283325
writer.writerow(encoded_labels)
@@ -300,6 +342,7 @@ def _save_header(self):
300342
col_line.extend([""] * (len(index_label) - 1))
301343

302344
col_line.extend(columns._get_level_values(i))
345+
self._bytes_to_str(col_line)
303346

304347
writer.writerow(col_line)
305348

@@ -340,6 +383,7 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
340383
b = blocks[i]
341384
d = b.to_native_types(
342385
na_rep=self.na_rep,
386+
bytes_encoding=self.bytes_encoding,
343387
float_format=self.float_format,
344388
decimal=self.decimal,
345389
date_format=self.date_format,
@@ -353,10 +397,23 @@ def _save_chunk(self, start_i: int, end_i: int) -> None:
353397
ix = data_index.to_native_types(
354398
slicer=slicer,
355399
na_rep=self.na_rep,
400+
bytes_encoding=self.bytes_encoding,
356401
float_format=self.float_format,
357402
decimal=self.decimal,
358403
date_format=self.date_format,
359404
quoting=self.quoting,
360405
)
361406

362407
libwriters.write_csv_rows(self.data, ix, self.nlevels, self.cols, self.writer)
408+
409+
def _bytes_to_str(self, values):
    """
    Decode a list of bytes labels to str, in place.

    ``values`` is modified only when every element is bytes and
    ``self.bytes_encoding`` is set; a list that mixes bytes with other
    types raises ``ValueError``.
    """
    as_objects = np.array(values, dtype=object)
    all_bytes = lib.is_bytes_array(as_objects)
    some_bytes = lib.is_any_bytes_in_array(as_objects)

    if some_bytes and not all_bytes:
        raise ValueError("Cannot mix types")

    # Nothing to decode unless the whole list is bytes and a codec is known.
    if self.bytes_encoding is None or not all_bytes:
        return

    for pos, raw in enumerate(values):
        values[pos] = raw.decode(self.bytes_encoding)

pandas/tests/frame/test_to_csv.py

+85
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,91 @@ def test_to_csv_withcommas(self):
740740
df2 = self.read_csv(path)
741741
tm.assert_frame_equal(df2, df)
742742

743+
def test_to_csv_bytes(self):
    """
    Check that to_csv decodes bytes labels and values instead of writing
    their ``b'...'`` repr, and that mixing bytes with non-bytes raises.
    """
    # GH 9712
    times = date_range("2013-10-27 23:00", "2013-10-28 00:00", freq="H")
    df = DataFrame(
        {b"hello": [b"abcd", b"world"], b"times": times}, index=[b"A", b"B"]
    )
    df.loc[b"C"] = np.nan
    df.index.name = b"idx"

    df_expected = DataFrame(
        {"hello": ["abcd", "world"], "times": times}, index=["A", "B"]
    )
    df_expected.loc["C"] = np.nan
    df_expected.index.name = "idx"

    with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
        df.to_csv(path, header=True)
        df_output = self.read_csv(path)
        df_output.times = to_datetime(df_output.times)
        tm.assert_frame_equal(df_output, df_expected)

    # bytes that are not valid UTF-8 need an explicit bytes_encoding
    non_unicode_byte = b"\xbc\xa6"
    non_unicode_decoded = non_unicode_byte.decode("gb18030")
    df = DataFrame({non_unicode_byte: [non_unicode_byte, b"world"]})
    df.index.name = "idx"

    df_expected = DataFrame({non_unicode_decoded: [non_unicode_decoded, "world"]})
    df_expected.index.name = "idx"

    with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
        df.to_csv(path, bytes_encoding="gb18030", header=True)
        df_output = self.read_csv(path)
        tm.assert_frame_equal(df_output, df_expected)

    # decoding error, when transcoding fails
    with pytest.raises(UnicodeDecodeError):
        df.to_csv(bytes_encoding="utf-8")

    # mixing of bytes and non-bytes
    df = DataFrame({"hello": [b"abcd", "world"]})
    with pytest.raises(ValueError):
        df.to_csv()
    df = DataFrame({b"hello": ["a", "b"], "world": ["c", "d"]})
    with pytest.raises(ValueError):
        df.to_csv()
    df = DataFrame({"hello": ["a", "b"], "world": ["c", "d"]}, index=["A", b"B"])
    with pytest.raises(ValueError):
        df.to_csv()

    # multi-indexes
    iterables = [[b"A", b"B"], ["C", "D"]]
    index = pd.MultiIndex.from_product(iterables, names=[b"f", b"s"])
    data = np.array([[0, 0], [0, 0], [0, 0], [0, 0]])
    df = pd.DataFrame(data, index=index)

    with tm.ensure_clean("__tmp_to_csv_bytes__.csv") as path:
        df.to_csv(path)
        with open(path) as csvfile:
            output = csvfile.readlines()

    expected = [
        "f,s,0,1\n",
        "A,C,0,0\n",
        "A,D,0,0\n",
        "B,C,0,0\n",
        "B,D,0,0\n",
    ]
    assert output == expected

    # mixing of bytes and non-bytes in multi-indexes
    iterables = [[b"A", "B"], ["C", "D"]]
    index = pd.MultiIndex.from_product(iterables)
    df = pd.DataFrame(data, index=index)
    with pytest.raises(ValueError):
        df.to_csv()

    iterables = [["A", "B"], ["C", "D"]]
    index = pd.MultiIndex.from_product(iterables, names=[b"f", "s"])
    df = pd.DataFrame(data, index=index)
    with pytest.raises(ValueError):
        df.to_csv()
827+
743828
def test_to_csv_mixed(self):
744829
def create_cols(name):
745830
return [f"{name}{i:03d}" for i in range(5)]

pandas/tests/io/formats/test_to_csv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def test_to_csv_stdout_file(self, capsys):
444444
expected_rows = [",name_1,name_2", "0,foo,bar", "1,baz,qux"]
445445
expected_ascii = tm.convert_rows_list_to_csv_str(expected_rows)
446446

447-
df.to_csv(sys.stdout, encoding="ascii")
447+
df.to_csv(sys.stdout, encoding="utf-8")
448448
captured = capsys.readouterr()
449449

450450
assert captured.out == expected_ascii

0 commit comments

Comments
 (0)