Skip to content

Commit 58c9ef7

Browse files
Backport PR #56771 on branch 2.2.x (BUG: to_stata not handling ea dtypes correctly) (#56783)
Backport PR #56771: BUG: to_stata not handling ea dtypes correctly

Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 6dbeeb4 commit 58c9ef7

File tree

3 files changed: 52 additions and 9 deletions.

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,7 @@ I/O
848848
- Bug in :func:`read_json` not handling dtype conversion properly if ``infer_string`` is set (:issue:`56195`)
849849
- Bug in :meth:`DataFrame.to_excel`, with ``OdsWriter`` (``ods`` files) writing Boolean/string value (:issue:`54994`)
850850
- Bug in :meth:`DataFrame.to_hdf` and :func:`read_hdf` with ``datetime64`` dtypes with non-nanosecond resolution failing to round-trip correctly (:issue:`55622`)
851+
- Bug in :meth:`DataFrame.to_stata` raising for extension dtypes (:issue:`54671`)
851852
- Bug in :meth:`~pandas.read_excel` with ``engine="odf"`` (``ods`` files) when a string cell contains an annotation (:issue:`55200`)
852853
- Bug in :meth:`~pandas.read_excel` with an ODS file without cached formatted cell for float values (:issue:`55219`)
853854
- Bug where :meth:`DataFrame.to_json` would raise an ``OverflowError`` instead of a ``TypeError`` with unsupported NumPy types (:issue:`55403`)

pandas/io/stata.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@
4747
)
4848
from pandas.util._exceptions import find_stack_level
4949

50+
from pandas.core.dtypes.base import ExtensionDtype
5051
from pandas.core.dtypes.common import (
5152
ensure_object,
5253
is_numeric_dtype,
54+
is_string_dtype,
5355
)
5456
from pandas.core.dtypes.dtypes import CategoricalDtype
5557

@@ -62,8 +64,6 @@
6264
to_datetime,
6365
to_timedelta,
6466
)
65-
from pandas.core.arrays.boolean import BooleanDtype
66-
from pandas.core.arrays.integer import IntegerDtype
6767
from pandas.core.frame import DataFrame
6868
from pandas.core.indexes.base import Index
6969
from pandas.core.indexes.range import RangeIndex
@@ -591,17 +591,22 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
591591

592592
for col in data:
593593
# Cast from unsupported types to supported types
594-
is_nullable_int = isinstance(data[col].dtype, (IntegerDtype, BooleanDtype))
594+
is_nullable_int = (
595+
isinstance(data[col].dtype, ExtensionDtype)
596+
and data[col].dtype.kind in "iub"
597+
)
595598
# We need to find orig_missing before altering data below
596599
orig_missing = data[col].isna()
597600
if is_nullable_int:
598-
missing_loc = data[col].isna()
599-
if missing_loc.any():
600-
# Replace with always safe value
601-
fv = 0 if isinstance(data[col].dtype, IntegerDtype) else False
602-
data.loc[missing_loc, col] = fv
601+
fv = 0 if data[col].dtype.kind in "iu" else False
603602
# Replace with NumPy-compatible column
604-
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
603+
data[col] = data[col].fillna(fv).astype(data[col].dtype.numpy_dtype)
604+
elif isinstance(data[col].dtype, ExtensionDtype):
605+
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
606+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
607+
elif is_string_dtype(data[col].dtype):
608+
data[col] = data[col].astype("object")
609+
605610
dtype = data[col].dtype
606611
empty_df = data.shape[0] == 0
607612
for c_data in conversion_data:

pandas/tests/io/test_stata.py

+37
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import pytest
1313

14+
import pandas.util._test_decorators as td
15+
1416
import pandas as pd
1517
from pandas import CategoricalDtype
1618
import pandas._testing as tm
@@ -1921,6 +1923,41 @@ def test_writer_118_exceptions(self):
19211923
with pytest.raises(ValueError, match="You must use version 119"):
19221924
StataWriterUTF8(path, df, version=118)
19231925

1926+
@pytest.mark.parametrize(
    "dtype_backend",
    ["numpy_nullable", pytest.param("pyarrow", marks=td.skip_if_no("pyarrow"))],
)
def test_read_write_ea_dtypes(self, dtype_backend):
    """Round-trip a frame of extension-dtype columns through ``to_stata``.

    The frame is converted with ``convert_dtypes`` to the parametrized
    ``dtype_backend`` and written with both the version=118 writer and the
    default writer. The file produced by the default writer is read back and
    compared against NumPy-backed expectations: the integer and boolean
    columns that contain a missing value come back with ``NaN`` in its place,
    and the index comes back as ``int32``.
    """
    # GH 54671
    df = DataFrame(
        {
            "a": [1, 2, None],
            "b": ["a", "b", "c"],
            "c": [True, False, None],
            "d": [1.5, 2.5, 3.5],
            "e": pd.date_range("2020-12-31", periods=3, freq="D"),
        },
        index=pd.Index([0, 1, 2], name="index"),
    )
    df = df.convert_dtypes(dtype_backend=dtype_backend)

    # Only check that the version=118 writer does not raise. Write to a
    # temporary path so the test does not leave "test_stata.dta" behind in
    # the current working directory.
    with tm.ensure_clean() as path_118:
        df.to_stata(path_118, version=118)

    with tm.ensure_clean() as path:
        df.to_stata(path)
        written_and_read_again = self.read_dta(path)

    expected = DataFrame(
        {
            "a": [1, 2, np.nan],
            "b": ["a", "b", "c"],
            "c": [1.0, 0, np.nan],
            "d": [1.5, 2.5, 3.5],
            "e": pd.date_range("2020-12-31", periods=3, freq="D"),
        },
        index=pd.Index([0, 1, 2], name="index", dtype=np.int32),
    )

    tm.assert_frame_equal(written_and_read_again.set_index("index"), expected)
1960+
19241961

19251962
@pytest.mark.parametrize("version", [105, 108, 111, 113, 114])
19261963
def test_backward_compat(version, datapath):

0 commit comments

Comments
 (0)