Skip to content

Commit 680ef1c

Browse files
committed
ENH: Add support for nullable boolean and integers in Stata writers
Add code that allows nullable arrays to be written closes #40855
1 parent 1cbf344 commit 680ef1c

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

doc/source/whatsnew/v1.4.0.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ Other enhancements
3131
^^^^^^^^^^^^^^^^^^
3232
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
3333
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
34-
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
34+
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
3535
- :meth:`Styler.bar` now validates the input argument ``width`` and ``height`` (:issue:`42511`)
3636
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
37-
-
37+
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
3838

3939
.. ---------------------------------------------------------------------------
4040

pandas/io/stata.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
to_timedelta,
6060
)
6161
from pandas.core import generic
62+
from pandas.core.arrays.boolean import BooleanDtype
63+
from pandas.core.arrays.integer import _IntegerDtype
6264
from pandas.core.frame import DataFrame
6365
from pandas.core.indexes.base import Index
6466
from pandas.core.series import Series
@@ -569,14 +571,24 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
569571
(np.uint8, np.int8, np.int16),
570572
(np.uint16, np.int16, np.int32),
571573
(np.uint32, np.int32, np.int64),
574+
(np.uint64, np.int64, np.float64),
572575
)
573576

574577
float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
575578
float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
576579

577580
for col in data:
578-
dtype = data[col].dtype
579581
# Cast from unsupported types to supported types
582+
is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
583+
orig = data[col]
584+
if is_nullable_int:
585+
missing_loc = data[col].isna()
586+
if missing_loc.any():
587+
# Replace with always safe value
588+
data.loc[missing_loc, col] = 0
589+
# Replace with NumPy-compatible column
590+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
591+
dtype = data[col].dtype
580592
for c_data in conversion_data:
581593
if dtype == c_data[0]:
582594
if data[col].max() <= np.iinfo(c_data[1]).max:
@@ -618,7 +630,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
618630
f"Column {col} has a maximum value ({value}) outside the range "
619631
f"supported by Stata ({float64_max})"
620632
)
621-
633+
if is_nullable_int:
634+
missing = orig.isna()
635+
if missing.any():
636+
# Replace missing by Stata sentinel value
637+
sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
638+
data.loc[missing, col] = sentinel
622639
if ws:
623640
warnings.warn(ws, PossiblePrecisionLoss)
624641

pandas/tests/io/test_stata.py

+42
Original file line numberDiff line numberDiff line change
@@ -2047,3 +2047,45 @@ def test_stata_compression(compression_only, read_infer, to_infer):
20472047
df.to_stata(path, compression=to_compression)
20482048
result = read_stata(path, compression=read_compression, index_col="index")
20492049
tm.assert_frame_equal(result, df)
2050+
2051+
2052+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
2053+
@pytest.mark.parametrize(
2054+
"dtype",
2055+
[
2056+
pd.BooleanDtype,
2057+
pd.Int8Dtype,
2058+
pd.Int16Dtype,
2059+
pd.Int32Dtype,
2060+
pd.Int64Dtype,
2061+
pd.UInt8Dtype,
2062+
pd.UInt16Dtype,
2063+
pd.UInt32Dtype,
2064+
pd.UInt64Dtype,
2065+
],
2066+
)
2067+
def test_nullable_support(dtype, version):
2068+
df = DataFrame(
2069+
{
2070+
"a": Series([1.0, 2.0, 3.0]),
2071+
"b": Series([1, pd.NA, pd.NA], dtype=dtype.name),
2072+
"c": Series(["a", "b", None]),
2073+
}
2074+
)
2075+
dtype_name = df.b.dtype.numpy_dtype.name
2076+
# Only use supported names: no uint, bool or int64
2077+
dtype_name = dtype_name.replace("u", "")
2078+
if dtype_name == "int64":
2079+
dtype_name = "int32"
2080+
elif dtype_name == "bool":
2081+
dtype_name = "int8"
2082+
value = StataMissingValue.BASE_MISSING_VALUES[dtype_name]
2083+
smv = StataMissingValue(value)
2084+
expected_b = Series([1, smv, smv], dtype=object, name="b")
2085+
expected_c = Series(["a", "b", ""], name="c")
2086+
with tm.ensure_clean() as path:
2087+
df.to_stata(path, write_index=False, version=version)
2088+
reread = read_stata(path, convert_missing=True)
2089+
tm.assert_series_equal(df.a, reread.a)
2090+
tm.assert_series_equal(reread.b, expected_b)
2091+
tm.assert_series_equal(reread.c, expected_c)

0 commit comments

Comments
 (0)