Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4c41b64

Browse files
committedAug 30, 2021
ENH: Add support for nullable boolean and integers in Stata writers
Add code that allows nullable arrays to be written closes #40855
1 parent 303fc9a commit 4c41b64

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed
 

‎doc/source/whatsnew/v1.4.0.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,14 @@ Other enhancements
9393
- :class:`DataFrameGroupBy` operations with ``as_index=False`` now correctly retain ``ExtensionDtype`` dtypes for columns being grouped on (:issue:`41373`)
9494
- Add support for assigning values to ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
9595
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
96+
- Additional options added to :meth:`.Styler.bar` to control alignment and display, with keyword only arguments (:issue:`26070`, :issue:`36419`)
97+
- :meth:`Styler.bar` now validates the input argument ``width`` and ``height`` (:issue:`42511`)
9698
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
9799
- :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`)
98100
-
101+
-
102+
- Added support for nullable boolean and integer types in :meth:`DataFrame.to_stata`, :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`, and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`40855`)
103+
-
99104

100105
.. ---------------------------------------------------------------------------
101106

‎pandas/io/stata.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
to_timedelta,
6060
)
6161
from pandas.core import generic
62+
from pandas.core.arrays.boolean import BooleanDtype
63+
from pandas.core.arrays.integer import _IntegerDtype
6264
from pandas.core.frame import DataFrame
6365
from pandas.core.indexes.base import Index
6466
from pandas.core.series import Series
@@ -584,14 +586,24 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
584586
(np.uint8, np.int8, np.int16),
585587
(np.uint16, np.int16, np.int32),
586588
(np.uint32, np.int32, np.int64),
589+
(np.uint64, np.int64, np.float64),
587590
)
588591

589592
float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
590593
float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]
591594

592595
for col in data:
593-
dtype = data[col].dtype
594596
# Cast from unsupported types to supported types
597+
is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
598+
orig = data[col]
599+
if is_nullable_int:
600+
missing_loc = data[col].isna()
601+
if missing_loc.any():
602+
# Replace with always safe value
603+
data.loc[missing_loc, col] = 0
604+
# Replace with NumPy-compatible column
605+
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
606+
dtype = data[col].dtype
595607
for c_data in conversion_data:
596608
if dtype == c_data[0]:
597609
if data[col].max() <= np.iinfo(c_data[1]).max:
@@ -633,7 +645,12 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
633645
f"Column {col} has a maximum value ({value}) outside the range "
634646
f"supported by Stata ({float64_max})"
635647
)
636-
648+
if is_nullable_int:
649+
missing = orig.isna()
650+
if missing.any():
651+
# Replace missing by Stata sentinel value
652+
sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
653+
data.loc[missing, col] = sentinel
637654
if ws:
638655
warnings.warn(ws, PossiblePrecisionLoss)
639656

‎pandas/tests/io/test_stata.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,3 +2048,45 @@ def test_stata_compression(compression_only, read_infer, to_infer):
20482048
df.to_stata(path, compression=to_compression)
20492049
result = read_stata(path, compression=read_compression, index_col="index")
20502050
tm.assert_frame_equal(result, df)
2051+
2052+
2053+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
2054+
@pytest.mark.parametrize(
2055+
"dtype",
2056+
[
2057+
pd.BooleanDtype,
2058+
pd.Int8Dtype,
2059+
pd.Int16Dtype,
2060+
pd.Int32Dtype,
2061+
pd.Int64Dtype,
2062+
pd.UInt8Dtype,
2063+
pd.UInt16Dtype,
2064+
pd.UInt32Dtype,
2065+
pd.UInt64Dtype,
2066+
],
2067+
)
2068+
def test_nullable_support(dtype, version):
2069+
df = DataFrame(
2070+
{
2071+
"a": Series([1.0, 2.0, 3.0]),
2072+
"b": Series([1, pd.NA, pd.NA], dtype=dtype.name),
2073+
"c": Series(["a", "b", None]),
2074+
}
2075+
)
2076+
dtype_name = df.b.dtype.numpy_dtype.name
2077+
# Only use supported names: no uint, bool or int64
2078+
dtype_name = dtype_name.replace("u", "")
2079+
if dtype_name == "int64":
2080+
dtype_name = "int32"
2081+
elif dtype_name == "bool":
2082+
dtype_name = "int8"
2083+
value = StataMissingValue.BASE_MISSING_VALUES[dtype_name]
2084+
smv = StataMissingValue(value)
2085+
expected_b = Series([1, smv, smv], dtype=object, name="b")
2086+
expected_c = Series(["a", "b", ""], name="c")
2087+
with tm.ensure_clean() as path:
2088+
df.to_stata(path, write_index=False, version=version)
2089+
reread = read_stata(path, convert_missing=True)
2090+
tm.assert_series_equal(df.a, reread.a)
2091+
tm.assert_series_equal(reread.b, expected_b)
2092+
tm.assert_series_equal(reread.c, expected_c)

0 commit comments

Comments
 (0)
Please sign in to comment.