Skip to content

Commit b01bbe2

Browse files
committed
ENH: Add support for nullable boolean and integers in Stata writers
Add code that allows nullable arrays to be written closes pandas-dev#40855
1 parent a6943ae commit b01bbe2

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

doc/source/whatsnew/v1.4.0.rst

+4
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,13 @@ Other enhancements
9494
- :class:`DataFrameGroupBy` operations with ``as_index=False`` now correctly retain ``ExtensionDtype`` dtypes for columns being grouped on (:issue:`41373`)
9595
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
9696
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
97+
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
98+
- :meth:`Styler.bar` now validates the input argument ``width`` and ``height`` (:issue:`42511`)
9799
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
98100
- :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`)
99101
- :meth:`read_table` now supports the argument ``storage_options`` (:issue:`39167`)
102+
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
103+
-
100104

101105
.. ---------------------------------------------------------------------------
102106

pandas/io/stata.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
to_timedelta,
5959
)
6060
from pandas.core import generic
61+
from pandas.core.arrays.boolean import BooleanDtype
62+
from pandas.core.arrays.integer import _IntegerDtype
6163
from pandas.core.frame import DataFrame
6264
from pandas.core.indexes.base import Index
6365
from pandas.core.series import Series
@@ -583,14 +585,24 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
583585
(np.uint8, np.int8, np.int16),
584586
(np.uint16, np.int16, np.int32),
585587
(np.uint32, np.int32, np.int64),
588+
(np.uint64, np.int64, np.float64),
586589
)
587590

588591
float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
589592
float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
590593

591594
for col in data:
592-
dtype = data[col].dtype
593595
# Cast from unsupported types to supported types
596+
is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
597+
orig = data[col]
598+
if is_nullable_int:
599+
missing_loc = data[col].isna()
600+
if missing_loc.any():
601+
# Replace with always safe value
602+
data.loc[missing_loc, col] = 0
603+
# Replace with NumPy-compatible column
604+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
605+
dtype = data[col].dtype
594606
for c_data in conversion_data:
595607
if dtype == c_data[0]:
596608
if data[col].max() <= np.iinfo(c_data[1]).max:
@@ -632,7 +644,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
632644
f"Column {col} has a maximum value ({value}) outside the range "
633645
f"supported by Stata ({float64_max})"
634646
)
635-
647+
if is_nullable_int:
648+
missing = orig.isna()
649+
if missing.any():
650+
# Replace missing by Stata sentinel value
651+
sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
652+
data.loc[missing, col] = sentinel
636653
if ws:
637654
warnings.warn(ws, PossiblePrecisionLoss)
638655

pandas/tests/io/test_stata.py

+42
Original file line numberDiff line numberDiff line change
@@ -2048,3 +2048,45 @@ def test_stata_compression(compression_only, read_infer, to_infer):
20482048
df.to_stata(path, compression=to_compression)
20492049
result = read_stata(path, compression=read_compression, index_col="index")
20502050
tm.assert_frame_equal(result, df)
2051+
2052+
2053+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
2054+
@pytest.mark.parametrize(
2055+
"dtype",
2056+
[
2057+
pd.BooleanDtype,
2058+
pd.Int8Dtype,
2059+
pd.Int16Dtype,
2060+
pd.Int32Dtype,
2061+
pd.Int64Dtype,
2062+
pd.UInt8Dtype,
2063+
pd.UInt16Dtype,
2064+
pd.UInt32Dtype,
2065+
pd.UInt64Dtype,
2066+
],
2067+
)
2068+
def test_nullable_support(dtype, version):
2069+
df = DataFrame(
2070+
{
2071+
"a": Series([1.0, 2.0, 3.0]),
2072+
"b": Series([1, pd.NA, pd.NA], dtype=dtype.name),
2073+
"c": Series(["a", "b", None]),
2074+
}
2075+
)
2076+
dtype_name = df.b.dtype.numpy_dtype.name
2077+
# Only use supported names: no uint, bool or int64
2078+
dtype_name = dtype_name.replace("u", "")
2079+
if dtype_name == "int64":
2080+
dtype_name = "int32"
2081+
elif dtype_name == "bool":
2082+
dtype_name = "int8"
2083+
value = StataMissingValue.BASE_MISSING_VALUES[dtype_name]
2084+
smv = StataMissingValue(value)
2085+
expected_b = Series([1, smv, smv], dtype=object, name="b")
2086+
expected_c = Series(["a", "b", ""], name="c")
2087+
with tm.ensure_clean() as path:
2088+
df.to_stata(path, write_index=False, version=version)
2089+
reread = read_stata(path, convert_missing=True)
2090+
tm.assert_series_equal(df.a, reread.a)
2091+
tm.assert_series_equal(reread.b, expected_b)
2092+
tm.assert_series_equal(reread.c, expected_c)

0 commit comments

Comments
 (0)