Skip to content

Commit 587f1db

Browse files
authored
ENH: Add support for nullable boolean and integers in Stata writers (#42565)
1 parent c74cf3a commit 587f1db

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

doc/source/whatsnew/v1.4.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ Other enhancements
123123
- Methods that relied on hashmap based algos such as :meth:`DataFrameGroupBy.value_counts`, :meth:`DataFrameGroupBy.count` and :func:`factorize` ignored imaginary component for complex numbers (:issue:`17927`)
124124
- Add :meth:`Series.str.removeprefix` and :meth:`Series.str.removesuffix` introduced in Python 3.9 to remove pre-/suffixes from string-type :class:`Series` (:issue:`36944`)
125125
- Attempting to write into a file in missing parent directory with :meth:`DataFrame.to_csv`, :meth:`DataFrame.to_html`, :meth:`DataFrame.to_excel`, :meth:`DataFrame.to_feather`, :meth:`DataFrame.to_parquet`, :meth:`DataFrame.to_stata`, :meth:`DataFrame.to_json`, :meth:`DataFrame.to_pickle`, and :meth:`DataFrame.to_xml` now explicitly mentions missing parent directory, the same is true for :class:`Series` counterparts (:issue:`24306`)
126-
126+
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
127+
-
127128

128129
.. ---------------------------------------------------------------------------
129130

pandas/io/stata.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
to_timedelta,
6161
)
6262
from pandas.core import generic
63+
from pandas.core.arrays.boolean import BooleanDtype
64+
from pandas.core.arrays.integer import _IntegerDtype
6365
from pandas.core.frame import DataFrame
6466
from pandas.core.indexes.base import Index
6567
from pandas.core.series import Series
@@ -588,14 +590,24 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
588590
(np.uint8, np.int8, np.int16),
589591
(np.uint16, np.int16, np.int32),
590592
(np.uint32, np.int32, np.int64),
593+
(np.uint64, np.int64, np.float64),
591594
)
592595

593596
float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
594597
float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
595598

596599
for col in data:
597-
dtype = data[col].dtype
598600
# Cast from unsupported types to supported types
601+
is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
602+
orig = data[col]
603+
if is_nullable_int:
604+
missing_loc = data[col].isna()
605+
if missing_loc.any():
606+
# Replace with always safe value
607+
data.loc[missing_loc, col] = 0
608+
# Replace with NumPy-compatible column
609+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
610+
dtype = data[col].dtype
599611
for c_data in conversion_data:
600612
if dtype == c_data[0]:
601613
if data[col].max() <= np.iinfo(c_data[1]).max:
@@ -637,7 +649,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
637649
f"Column {col} has a maximum value ({value}) outside the range "
638650
f"supported by Stata ({float64_max})"
639651
)
640-
652+
if is_nullable_int:
653+
missing = orig.isna()
654+
if missing.any():
655+
# Replace missing by Stata sentinel value
656+
sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
657+
data.loc[missing, col] = sentinel
641658
if ws:
642659
warnings.warn(ws, PossiblePrecisionLoss)
643660

pandas/tests/io/test_stata.py

+42
Original file line numberDiff line numberDiff line change
@@ -2162,3 +2162,45 @@ def test_non_categorical_value_label_convert_categoricals_error():
21622162
"""
21632163
with pytest.raises(ValueError, match=msg):
21642164
read_stata(path, convert_categoricals=True)
2165+
2166+
2167+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
2168+
@pytest.mark.parametrize(
2169+
"dtype",
2170+
[
2171+
pd.BooleanDtype,
2172+
pd.Int8Dtype,
2173+
pd.Int16Dtype,
2174+
pd.Int32Dtype,
2175+
pd.Int64Dtype,
2176+
pd.UInt8Dtype,
2177+
pd.UInt16Dtype,
2178+
pd.UInt32Dtype,
2179+
pd.UInt64Dtype,
2180+
],
2181+
)
2182+
def test_nullable_support(dtype, version):
2183+
df = DataFrame(
2184+
{
2185+
"a": Series([1.0, 2.0, 3.0]),
2186+
"b": Series([1, pd.NA, pd.NA], dtype=dtype.name),
2187+
"c": Series(["a", "b", None]),
2188+
}
2189+
)
2190+
dtype_name = df.b.dtype.numpy_dtype.name
2191+
# Only use supported names: no uint, bool or int64
2192+
dtype_name = dtype_name.replace("u", "")
2193+
if dtype_name == "int64":
2194+
dtype_name = "int32"
2195+
elif dtype_name == "bool":
2196+
dtype_name = "int8"
2197+
value = StataMissingValue.BASE_MISSING_VALUES[dtype_name]
2198+
smv = StataMissingValue(value)
2199+
expected_b = Series([1, smv, smv], dtype=object, name="b")
2200+
expected_c = Series(["a", "b", ""], name="c")
2201+
with tm.ensure_clean() as path:
2202+
df.to_stata(path, write_index=False, version=version)
2203+
reread = read_stata(path, convert_missing=True)
2204+
tm.assert_series_equal(df.a, reread.a)
2205+
tm.assert_series_equal(reread.b, expected_b)
2206+
tm.assert_series_equal(reread.c, expected_c)

0 commit comments

Comments
 (0)