Skip to content

Commit fd151ba

Browse files
lmcindewarjreback
andauthored
ENH: option to export df to Stata dataset with value labels (#41042)
Removing unnecessary list comprehension, flake8 Adding value_labels argument to DataFrame to_stata method Updating types and changing ValueError to KeyError for missing column Using converted names for invalid Stata variable names Moving value_labels to key word only for to_stata Adding tests for invalid Stata names and repeated value labels Fixing Literal import Moving label encoding to method Adding versionaddeds Co-authored-by: Jeff Reback <[email protected]>
1 parent 00e10a5 commit fd151ba

File tree

4 files changed

+262
-10
lines changed

4 files changed

+262
-10
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Other enhancements
104104
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)
105105
- :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` now support the argument ``skipna`` (:issue:`34047`)
106106
- :meth:`read_table` now supports the argument ``storage_options`` (:issue:`39167`)
107+
- :meth:`DataFrame.to_stata` and :meth:`StataWriter` now accept the keyword only argument ``value_labels`` to save labels for non-categorical columns
107108
- Methods that relied on hashmap based algos such as :meth:`DataFrameGroupBy.value_counts`, :meth:`DataFrameGroupBy.count` and :func:`factorize` ignored imaginary component for complex numbers (:issue:`17927`)
108109
- Add :meth:`Series.str.removeprefix` and :meth:`Series.str.removesuffix` introduced in Python 3.9 to remove pre-/suffixes from string-type :class:`Series` (:issue:`36944`)
109110

pandas/core/frame.py

+10
Original file line numberDiff line numberDiff line change
@@ -2393,6 +2393,8 @@ def to_stata(
23932393
convert_strl: Sequence[Hashable] | None = None,
23942394
compression: CompressionOptions = "infer",
23952395
storage_options: StorageOptions = None,
2396+
*,
2397+
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
23962398
) -> None:
23972399
"""
23982400
Export DataFrame object to Stata dta format.
@@ -2474,6 +2476,13 @@ def to_stata(
24742476
24752477
.. versionadded:: 1.2.0
24762478
2479+
value_labels : dict of dicts
2480+
Dictionary containing columns as keys and dictionaries of column value
2481+
to labels as values. Labels for a single variable must be 32,000
2482+
characters or smaller.
2483+
2484+
.. versionadded:: 1.4.0
2485+
24772486
Raises
24782487
------
24792488
NotImplementedError
@@ -2535,6 +2544,7 @@ def to_stata(
25352544
variable_labels=variable_labels,
25362545
compression=compression,
25372546
storage_options=storage_options,
2547+
value_labels=value_labels,
25382548
**kwargs,
25392549
)
25402550
writer.write_file()

pandas/io/stata.py

+137-10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import struct
1919
import sys
2020
from typing import (
21+
TYPE_CHECKING,
2122
Any,
2223
AnyStr,
2324
Hashable,
@@ -46,6 +47,7 @@
4647
ensure_object,
4748
is_categorical_dtype,
4849
is_datetime64_dtype,
50+
is_numeric_dtype,
4951
)
5052

5153
from pandas import (
@@ -64,6 +66,9 @@
6466

6567
from pandas.io.common import get_handle
6668

69+
if TYPE_CHECKING:
70+
from typing import Literal
71+
6772
_version_error = (
6873
"Version of given Stata file is {version}. pandas supports importing "
6974
"versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
@@ -658,24 +663,37 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
658663
self.labname = catarray.name
659664
self._encoding = encoding
660665
categories = catarray.cat.categories
661-
self.value_labels = list(zip(np.arange(len(categories)), categories))
666+
self.value_labels: list[tuple[int | float, str]] = list(
667+
zip(np.arange(len(categories)), categories)
668+
)
662669
self.value_labels.sort(key=lambda x: x[0])
670+
671+
self._prepare_value_labels()
672+
673+
def _prepare_value_labels(self):
674+
"""Encode value labels."""
675+
663676
self.text_len = 0
664677
self.txt: list[bytes] = []
665678
self.n = 0
679+
# Offsets (length of categories), converted to int32
680+
self.off = np.array([])
681+
# Values, converted to int32
682+
self.val = np.array([])
683+
self.len = 0
666684

667685
# Compute lengths and setup lists of offsets and labels
668686
offsets: list[int] = []
669-
values: list[int] = []
687+
values: list[int | float] = []
670688
for vl in self.value_labels:
671-
category = vl[1]
689+
category: str | bytes = vl[1]
672690
if not isinstance(category, str):
673691
category = str(category)
674692
warnings.warn(
675-
value_label_mismatch_doc.format(catarray.name),
693+
value_label_mismatch_doc.format(self.labname),
676694
ValueLabelTypeMismatch,
677695
)
678-
category = category.encode(encoding)
696+
category = category.encode(self._encoding)
679697
offsets.append(self.text_len)
680698
self.text_len += len(category) + 1 # +1 for the padding
681699
values.append(vl[0])
@@ -748,6 +766,38 @@ def generate_value_label(self, byteorder: str) -> bytes:
748766
return bio.getvalue()
749767

750768

769+
class StataNonCatValueLabel(StataValueLabel):
770+
"""
771+
Prepare formatted version of value labels
772+
773+
Parameters
774+
----------
775+
labname : str
776+
Value label name
777+
value_labels: Dictionary
778+
Mapping of values to labels
779+
encoding : {"latin-1", "utf-8"}
780+
Encoding to use for value labels.
781+
"""
782+
783+
def __init__(
784+
self,
785+
labname: str,
786+
value_labels: dict[float | int, str],
787+
encoding: Literal["latin-1", "utf-8"] = "latin-1",
788+
):
789+
790+
if encoding not in ("latin-1", "utf-8"):
791+
raise ValueError("Only latin-1 and utf-8 are supported.")
792+
793+
self.labname = labname
794+
self._encoding = encoding
795+
self.value_labels: list[tuple[int | float, str]] = sorted(
796+
value_labels.items(), key=lambda x: x[0]
797+
)
798+
self._prepare_value_labels()
799+
800+
751801
class StataMissingValue:
752802
"""
753803
An observation's missing value.
@@ -2175,6 +2225,13 @@ class StataWriter(StataParser):
21752225
21762226
.. versionadded:: 1.2.0
21772227
2228+
value_labels : dict of dicts
2229+
Dictionary containing columns as keys and dictionaries of column value
2230+
to labels as values. The combined length of all labels for a single
2231+
variable must be 32,000 characters or smaller.
2232+
2233+
.. versionadded:: 1.4.0
2234+
21782235
Returns
21792236
-------
21802237
writer : StataWriter instance
@@ -2225,15 +2282,22 @@ def __init__(
22252282
variable_labels: dict[Hashable, str] | None = None,
22262283
compression: CompressionOptions = "infer",
22272284
storage_options: StorageOptions = None,
2285+
*,
2286+
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
22282287
):
22292288
super().__init__()
2289+
self.data = data
22302290
self._convert_dates = {} if convert_dates is None else convert_dates
22312291
self._write_index = write_index
22322292
self._time_stamp = time_stamp
22332293
self._data_label = data_label
22342294
self._variable_labels = variable_labels
2295+
self._non_cat_value_labels = value_labels
2296+
self._value_labels: list[StataValueLabel] = []
2297+
self._has_value_labels = np.array([], dtype=bool)
22352298
self._compression = compression
22362299
self._output_file: Buffer | None = None
2300+
self._converted_names: dict[Hashable, str] = {}
22372301
# attach nobs, nvars, data, varlist, typlist
22382302
self._prepare_pandas(data)
22392303
self.storage_options = storage_options
@@ -2243,7 +2307,6 @@ def __init__(
22432307
self._byteorder = _set_endianness(byteorder)
22442308
self._fname = fname
22452309
self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
2246-
self._converted_names: dict[Hashable, str] = {}
22472310

22482311
def _write(self, to_write: str) -> None:
22492312
"""
@@ -2259,17 +2322,50 @@ def _write_bytes(self, value: bytes) -> None:
22592322
"""
22602323
self.handles.handle.write(value) # type: ignore[arg-type]
22612324

2325+
def _prepare_non_cat_value_labels(
2326+
self, data: DataFrame
2327+
) -> list[StataNonCatValueLabel]:
2328+
"""
2329+
Check for value labels provided for non-categorical columns. Value
2330+
labels
2331+
"""
2332+
non_cat_value_labels: list[StataNonCatValueLabel] = []
2333+
if self._non_cat_value_labels is None:
2334+
return non_cat_value_labels
2335+
2336+
for labname, labels in self._non_cat_value_labels.items():
2337+
if labname in self._converted_names:
2338+
colname = self._converted_names[labname]
2339+
elif labname in data.columns:
2340+
colname = str(labname)
2341+
else:
2342+
raise KeyError(
2343+
f"Can't create value labels for {labname}, it wasn't "
2344+
"found in the dataset."
2345+
)
2346+
2347+
if not is_numeric_dtype(data[colname].dtype):
2348+
# Labels should not be passed explicitly for categorical
2349+
# columns that will be converted to int
2350+
raise ValueError(
2351+
f"Can't create value labels for {labname}, value labels "
2352+
"can only be applied to numeric columns."
2353+
)
2354+
svl = StataNonCatValueLabel(colname, labels)
2355+
non_cat_value_labels.append(svl)
2356+
return non_cat_value_labels
2357+
22622358
def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
22632359
"""
22642360
Check for categorical columns, retain categorical information for
22652361
Stata file and convert categorical data to int
22662362
"""
22672363
is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
2268-
self._is_col_cat = is_cat
2269-
self._value_labels: list[StataValueLabel] = []
22702364
if not any(is_cat):
22712365
return data
22722366

2367+
self._has_value_labels |= np.array(is_cat)
2368+
22732369
get_base_missing_value = StataMissingValue.get_base_missing_value
22742370
data_formatted = []
22752371
for col, col_is_cat in zip(data, is_cat):
@@ -2449,6 +2545,17 @@ def _prepare_pandas(self, data: DataFrame) -> None:
24492545
# Replace NaNs with Stata missing values
24502546
data = self._replace_nans(data)
24512547

2548+
# Set all columns to initially unlabelled
2549+
self._has_value_labels = np.repeat(False, data.shape[1])
2550+
2551+
# Create value labels for non-categorical data
2552+
non_cat_value_labels = self._prepare_non_cat_value_labels(data)
2553+
2554+
non_cat_columns = [svl.labname for svl in non_cat_value_labels]
2555+
has_non_cat_val_labels = data.columns.isin(non_cat_columns)
2556+
self._has_value_labels |= has_non_cat_val_labels
2557+
self._value_labels.extend(non_cat_value_labels)
2558+
24522559
# Convert categoricals to int data, and strip labels
24532560
data = self._prepare_categoricals(data)
24542561

@@ -2688,7 +2795,7 @@ def _write_value_label_names(self) -> None:
26882795
# lbllist, 33*nvar, char array
26892796
for i in range(self.nvar):
26902797
# Use variable name when categorical
2691-
if self._is_col_cat[i]:
2798+
if self._has_value_labels[i]:
26922799
name = self.varlist[i]
26932800
name = self._null_terminate_str(name)
26942801
name = _pad_bytes(name[:32], 33)
@@ -3059,6 +3166,13 @@ class StataWriter117(StataWriter):
30593166
30603167
.. versionadded:: 1.1.0
30613168
3169+
value_labels : dict of dicts
3170+
Dictionary containing columns as keys and dictionaries of column value
3171+
to labels as values. The combined length of all labels for a single
3172+
variable must be 32,000 characters or smaller.
3173+
3174+
.. versionadded:: 1.4.0
3175+
30623176
Returns
30633177
-------
30643178
writer : StataWriter117 instance
@@ -3112,6 +3226,8 @@ def __init__(
31123226
convert_strl: Sequence[Hashable] | None = None,
31133227
compression: CompressionOptions = "infer",
31143228
storage_options: StorageOptions = None,
3229+
*,
3230+
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
31153231
):
31163232
# Copy to new list since convert_strl might be modified later
31173233
self._convert_strl: list[Hashable] = []
@@ -3127,6 +3243,7 @@ def __init__(
31273243
time_stamp=time_stamp,
31283244
data_label=data_label,
31293245
variable_labels=variable_labels,
3246+
value_labels=value_labels,
31303247
compression=compression,
31313248
storage_options=storage_options,
31323249
)
@@ -3272,7 +3389,7 @@ def _write_value_label_names(self) -> None:
32723389
for i in range(self.nvar):
32733390
# Use variable name when categorical
32743391
name = "" # default name
3275-
if self._is_col_cat[i]:
3392+
if self._has_value_labels[i]:
32763393
name = self.varlist[i]
32773394
name = self._null_terminate_str(name)
32783395
encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
@@ -3449,6 +3566,13 @@ class StataWriterUTF8(StataWriter117):
34493566
34503567
.. versionadded:: 1.1.0
34513568
3569+
value_labels : dict of dicts
3570+
Dictionary containing columns as keys and dictionaries of column value
3571+
to labels as values. The combined length of all labels for a single
3572+
variable must be 32,000 characters or smaller.
3573+
3574+
.. versionadded:: 1.4.0
3575+
34523576
Returns
34533577
-------
34543578
StataWriterUTF8
@@ -3505,6 +3629,8 @@ def __init__(
35053629
version: int | None = None,
35063630
compression: CompressionOptions = "infer",
35073631
storage_options: StorageOptions = None,
3632+
*,
3633+
value_labels: dict[Hashable, dict[float | int, str]] | None = None,
35083634
):
35093635
if version is None:
35103636
version = 118 if data.shape[1] <= 32767 else 119
@@ -3525,6 +3651,7 @@ def __init__(
35253651
time_stamp=time_stamp,
35263652
data_label=data_label,
35273653
variable_labels=variable_labels,
3654+
value_labels=value_labels,
35283655
convert_strl=convert_strl,
35293656
compression=compression,
35303657
storage_options=storage_options,

0 commit comments

Comments
 (0)