Skip to content

Commit 8f5db02

Browse files
committed
ENH: Add support for nullable boolean and integers in Stata writers
Add code that allows nullable arrays to be written closes #40855
1 parent 9f90bd4 commit 8f5db02

File tree

3 files changed

+65
-3
lines changed

3 files changed

+65
-3
lines changed

doc/source/whatsnew/v1.4.0.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,17 @@ Other enhancements
116116
- :class:`DataFrameGroupBy` operations with ``as_index=False`` now correctly retain ``ExtensionDtype`` dtypes for columns being grouped on (:issue:`41373`)
117117
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
118118
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
119+
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
120+
- :meth:`Styler.bar` now validates the input argument ``width`` and ``height`` (:issue:`42511`)
119121
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
120122
- :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`)
121123
- :meth:`read_table` now supports the argument ``storage_options`` (:issue:`39167`)
122124
- :meth:`DataFrame.to_stata` and :meth:`StataWriter` now accept the keyword only argument ``value_labels`` to save labels for non-categorical columns
123125
- Methods that relied on hashmap based algos such as :meth:`DataFrameGroupBy.value_counts`, :meth:`DataFrameGroupBy.count` and :func:`factorize` ignored imaginary component for complex numbers (:issue:`17927`)
124126
- Add :meth:`Series.str.removeprefix` and :meth:`Series.str.removesuffix` introduced in Python 3.9 to remove pre-/suffixes from string-type :class:`Series` (:issue:`36944`)
125127
- Attempting to write into a file in missing parent directory with :meth:`DataFrame.to_csv`, :meth:`DataFrame.to_html`, :meth:`DataFrame.to_excel`, :meth:`DataFrame.to_feather`, :meth:`DataFrame.to_parquet`, :meth:`DataFrame.to_stata`, :meth:`DataFrame.to_json`, :meth:`DataFrame.to_pickle`, and :meth:`DataFrame.to_xml` now explicitly mentions missing parent directory, the same is true for :class:`Series` counterparts (:issue:`24306`)
126-
128+
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
129+
-
127130

128131
.. ---------------------------------------------------------------------------
129132

pandas/io/stata.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
to_timedelta,
6161
)
6262
from pandas.core import generic
63+
from pandas.core.arrays.boolean import BooleanDtype
64+
from pandas.core.arrays.integer import _IntegerDtype
6365
from pandas.core.frame import DataFrame
6466
from pandas.core.indexes.base import Index
6567
from pandas.core.series import Series
@@ -588,14 +590,24 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
588590
(np.uint8, np.int8, np.int16),
589591
(np.uint16, np.int16, np.int32),
590592
(np.uint32, np.int32, np.int64),
593+
(np.uint64, np.int64, np.float64),
591594
)
592595

593596
float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
594597
float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
595598

596599
for col in data:
597-
dtype = data[col].dtype
598600
# Cast from unsupported types to supported types
601+
is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
602+
orig = data[col]
603+
if is_nullable_int:
604+
missing_loc = data[col].isna()
605+
if missing_loc.any():
606+
# Replace with always safe value
607+
data.loc[missing_loc, col] = 0
608+
# Replace with NumPy-compatible column
609+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
610+
dtype = data[col].dtype
599611
for c_data in conversion_data:
600612
if dtype == c_data[0]:
601613
if data[col].max() <= np.iinfo(c_data[1]).max:
@@ -637,7 +649,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
637649
f"Column {col} has a maximum value ({value}) outside the range "
638650
f"supported by Stata ({float64_max})"
639651
)
640-
652+
if is_nullable_int:
653+
missing = orig.isna()
654+
if missing.any():
655+
# Replace missing by Stata sentinel value
656+
sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
657+
data.loc[missing, col] = sentinel
641658
if ws:
642659
warnings.warn(ws, PossiblePrecisionLoss)
643660

pandas/tests/io/test_stata.py

+42
Original file line numberDiff line numberDiff line change
@@ -2162,3 +2162,45 @@ def test_non_categorical_value_label_convert_categoricals_error():
21622162
"""
21632163
with pytest.raises(ValueError, match=msg):
21642164
read_stata(path, convert_categoricals=True)
2165+
2166+
2167+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
2168+
@pytest.mark.parametrize(
2169+
"dtype",
2170+
[
2171+
pd.BooleanDtype,
2172+
pd.Int8Dtype,
2173+
pd.Int16Dtype,
2174+
pd.Int32Dtype,
2175+
pd.Int64Dtype,
2176+
pd.UInt8Dtype,
2177+
pd.UInt16Dtype,
2178+
pd.UInt32Dtype,
2179+
pd.UInt64Dtype,
2180+
],
2181+
)
2182+
def test_nullable_support(dtype, version):
2183+
df = DataFrame(
2184+
{
2185+
"a": Series([1.0, 2.0, 3.0]),
2186+
"b": Series([1, pd.NA, pd.NA], dtype=dtype.name),
2187+
"c": Series(["a", "b", None]),
2188+
}
2189+
)
2190+
dtype_name = df.b.dtype.numpy_dtype.name
2191+
# Only use supported names: no uint, bool or int64
2192+
dtype_name = dtype_name.replace("u", "")
2193+
if dtype_name == "int64":
2194+
dtype_name = "int32"
2195+
elif dtype_name == "bool":
2196+
dtype_name = "int8"
2197+
value = StataMissingValue.BASE_MISSING_VALUES[dtype_name]
2198+
smv = StataMissingValue(value)
2199+
expected_b = Series([1, smv, smv], dtype=object, name="b")
2200+
expected_c = Series(["a", "b", ""], name="c")
2201+
with tm.ensure_clean() as path:
2202+
df.to_stata(path, write_index=False, version=version)
2203+
reread = read_stata(path, convert_missing=True)
2204+
tm.assert_series_equal(df.a, reread.a)
2205+
tm.assert_series_equal(reread.b, expected_b)
2206+
tm.assert_series_equal(reread.c, expected_c)

0 commit comments

Comments
 (0)