18
18
import struct
19
19
import sys
20
20
from typing import (
21
+ TYPE_CHECKING ,
21
22
Any ,
22
23
AnyStr ,
23
24
Hashable ,
46
47
ensure_object ,
47
48
is_categorical_dtype ,
48
49
is_datetime64_dtype ,
50
+ is_numeric_dtype ,
49
51
)
50
52
51
53
from pandas import (
64
66
65
67
from pandas .io .common import get_handle
66
68
69
+ if TYPE_CHECKING :
70
+ from typing import Literal
71
+
67
72
_version_error = (
68
73
"Version of given Stata file is {version}. pandas supports importing "
69
74
"versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
@@ -658,24 +663,37 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
658
663
self .labname = catarray .name
659
664
self ._encoding = encoding
660
665
categories = catarray .cat .categories
661
- self .value_labels = list (zip (np .arange (len (categories )), categories ))
666
+ self .value_labels : list [tuple [int | float , str ]] = list (
667
+ zip (np .arange (len (categories )), categories )
668
+ )
662
669
self .value_labels .sort (key = lambda x : x [0 ])
670
+
671
+ self ._prepare_value_labels ()
672
+
673
+ def _prepare_value_labels (self ):
674
+ """Encode value labels."""
675
+
663
676
self .text_len = 0
664
677
self .txt : list [bytes ] = []
665
678
self .n = 0
679
+ # Offsets (length of categories), converted to int32
680
+ self .off = np .array ([])
681
+ # Values, converted to int32
682
+ self .val = np .array ([])
683
+ self .len = 0
666
684
667
685
# Compute lengths and setup lists of offsets and labels
668
686
offsets : list [int ] = []
669
- values : list [int ] = []
687
+ values : list [int | float ] = []
670
688
for vl in self .value_labels :
671
- category = vl [1 ]
689
+ category : str | bytes = vl [1 ]
672
690
if not isinstance (category , str ):
673
691
category = str (category )
674
692
warnings .warn (
675
- value_label_mismatch_doc .format (catarray . name ),
693
+ value_label_mismatch_doc .format (self . labname ),
676
694
ValueLabelTypeMismatch ,
677
695
)
678
- category = category .encode (encoding )
696
+ category = category .encode (self . _encoding )
679
697
offsets .append (self .text_len )
680
698
self .text_len += len (category ) + 1 # +1 for the padding
681
699
values .append (vl [0 ])
@@ -748,6 +766,38 @@ def generate_value_label(self, byteorder: str) -> bytes:
748
766
return bio .getvalue ()
749
767
750
768
769
+ class StataNonCatValueLabel (StataValueLabel ):
770
+ """
771
+ Prepare formatted version of value labels
772
+
773
+ Parameters
774
+ ----------
775
+ labname : str
776
+ Value label name
777
+ value_labels: Dictionary
778
+ Mapping of values to labels
779
+ encoding : {"latin-1", "utf-8"}
780
+ Encoding to use for value labels.
781
+ """
782
+
783
+ def __init__ (
784
+ self ,
785
+ labname : str ,
786
+ value_labels : dict [float | int , str ],
787
+ encoding : Literal ["latin-1" , "utf-8" ] = "latin-1" ,
788
+ ):
789
+
790
+ if encoding not in ("latin-1" , "utf-8" ):
791
+ raise ValueError ("Only latin-1 and utf-8 are supported." )
792
+
793
+ self .labname = labname
794
+ self ._encoding = encoding
795
+ self .value_labels : list [tuple [int | float , str ]] = sorted (
796
+ value_labels .items (), key = lambda x : x [0 ]
797
+ )
798
+ self ._prepare_value_labels ()
799
+
800
+
751
801
class StataMissingValue :
752
802
"""
753
803
An observation's missing value.
@@ -2175,6 +2225,13 @@ class StataWriter(StataParser):
2175
2225
2176
2226
.. versionadded:: 1.2.0
2177
2227
2228
+ value_labels : dict of dicts
2229
+ Dictionary containing columns as keys and dictionaries of column value
2230
+ to labels as values. The combined length of all labels for a single
2231
+ variable must be 32,000 characters or smaller.
2232
+
2233
+ .. versionadded:: 1.4.0
2234
+
2178
2235
Returns
2179
2236
-------
2180
2237
writer : StataWriter instance
@@ -2225,15 +2282,22 @@ def __init__(
2225
2282
variable_labels : dict [Hashable , str ] | None = None ,
2226
2283
compression : CompressionOptions = "infer" ,
2227
2284
storage_options : StorageOptions = None ,
2285
+ * ,
2286
+ value_labels : dict [Hashable , dict [float | int , str ]] | None = None ,
2228
2287
):
2229
2288
super ().__init__ ()
2289
+ self .data = data
2230
2290
self ._convert_dates = {} if convert_dates is None else convert_dates
2231
2291
self ._write_index = write_index
2232
2292
self ._time_stamp = time_stamp
2233
2293
self ._data_label = data_label
2234
2294
self ._variable_labels = variable_labels
2295
+ self ._non_cat_value_labels = value_labels
2296
+ self ._value_labels : list [StataValueLabel ] = []
2297
+ self ._has_value_labels = np .array ([], dtype = bool )
2235
2298
self ._compression = compression
2236
2299
self ._output_file : Buffer | None = None
2300
+ self ._converted_names : dict [Hashable , str ] = {}
2237
2301
# attach nobs, nvars, data, varlist, typlist
2238
2302
self ._prepare_pandas (data )
2239
2303
self .storage_options = storage_options
@@ -2243,7 +2307,6 @@ def __init__(
2243
2307
self ._byteorder = _set_endianness (byteorder )
2244
2308
self ._fname = fname
2245
2309
self .type_converters = {253 : np .int32 , 252 : np .int16 , 251 : np .int8 }
2246
- self ._converted_names : dict [Hashable , str ] = {}
2247
2310
2248
2311
def _write (self , to_write : str ) -> None :
2249
2312
"""
@@ -2259,17 +2322,50 @@ def _write_bytes(self, value: bytes) -> None:
2259
2322
"""
2260
2323
self .handles .handle .write (value ) # type: ignore[arg-type]
2261
2324
2325
+ def _prepare_non_cat_value_labels (
2326
+ self , data : DataFrame
2327
+ ) -> list [StataNonCatValueLabel ]:
2328
+ """
2329
+ Check for value labels provided for non-categorical columns. Value
2330
+ labels
2331
+ """
2332
+ non_cat_value_labels : list [StataNonCatValueLabel ] = []
2333
+ if self ._non_cat_value_labels is None :
2334
+ return non_cat_value_labels
2335
+
2336
+ for labname , labels in self ._non_cat_value_labels .items ():
2337
+ if labname in self ._converted_names :
2338
+ colname = self ._converted_names [labname ]
2339
+ elif labname in data .columns :
2340
+ colname = str (labname )
2341
+ else :
2342
+ raise KeyError (
2343
+ f"Can't create value labels for { labname } , it wasn't "
2344
+ "found in the dataset."
2345
+ )
2346
+
2347
+ if not is_numeric_dtype (data [colname ].dtype ):
2348
+ # Labels should not be passed explicitly for categorical
2349
+ # columns that will be converted to int
2350
+ raise ValueError (
2351
+ f"Can't create value labels for { labname } , value labels "
2352
+ "can only be applied to numeric columns."
2353
+ )
2354
+ svl = StataNonCatValueLabel (colname , labels )
2355
+ non_cat_value_labels .append (svl )
2356
+ return non_cat_value_labels
2357
+
2262
2358
def _prepare_categoricals (self , data : DataFrame ) -> DataFrame :
2263
2359
"""
2264
2360
Check for categorical columns, retain categorical information for
2265
2361
Stata file and convert categorical data to int
2266
2362
"""
2267
2363
is_cat = [is_categorical_dtype (data [col ].dtype ) for col in data ]
2268
- self ._is_col_cat = is_cat
2269
- self ._value_labels : list [StataValueLabel ] = []
2270
2364
if not any (is_cat ):
2271
2365
return data
2272
2366
2367
+ self ._has_value_labels |= np .array (is_cat )
2368
+
2273
2369
get_base_missing_value = StataMissingValue .get_base_missing_value
2274
2370
data_formatted = []
2275
2371
for col , col_is_cat in zip (data , is_cat ):
@@ -2449,6 +2545,17 @@ def _prepare_pandas(self, data: DataFrame) -> None:
2449
2545
# Replace NaNs with Stata missing values
2450
2546
data = self ._replace_nans (data )
2451
2547
2548
+ # Set all columns to initially unlabelled
2549
+ self ._has_value_labels = np .repeat (False , data .shape [1 ])
2550
+
2551
+ # Create value labels for non-categorical data
2552
+ non_cat_value_labels = self ._prepare_non_cat_value_labels (data )
2553
+
2554
+ non_cat_columns = [svl .labname for svl in non_cat_value_labels ]
2555
+ has_non_cat_val_labels = data .columns .isin (non_cat_columns )
2556
+ self ._has_value_labels |= has_non_cat_val_labels
2557
+ self ._value_labels .extend (non_cat_value_labels )
2558
+
2452
2559
# Convert categoricals to int data, and strip labels
2453
2560
data = self ._prepare_categoricals (data )
2454
2561
@@ -2688,7 +2795,7 @@ def _write_value_label_names(self) -> None:
2688
2795
# lbllist, 33*nvar, char array
2689
2796
for i in range (self .nvar ):
2690
2797
# Use variable name when categorical
2691
- if self ._is_col_cat [i ]:
2798
+ if self ._has_value_labels [i ]:
2692
2799
name = self .varlist [i ]
2693
2800
name = self ._null_terminate_str (name )
2694
2801
name = _pad_bytes (name [:32 ], 33 )
@@ -3059,6 +3166,13 @@ class StataWriter117(StataWriter):
3059
3166
3060
3167
.. versionadded:: 1.1.0
3061
3168
3169
+ value_labels : dict of dicts
3170
+ Dictionary containing columns as keys and dictionaries of column value
3171
+ to labels as values. The combined length of all labels for a single
3172
+ variable must be 32,000 characters or smaller.
3173
+
3174
+ .. versionadded:: 1.4.0
3175
+
3062
3176
Returns
3063
3177
-------
3064
3178
writer : StataWriter117 instance
@@ -3112,6 +3226,8 @@ def __init__(
3112
3226
convert_strl : Sequence [Hashable ] | None = None ,
3113
3227
compression : CompressionOptions = "infer" ,
3114
3228
storage_options : StorageOptions = None ,
3229
+ * ,
3230
+ value_labels : dict [Hashable , dict [float | int , str ]] | None = None ,
3115
3231
):
3116
3232
# Copy to new list since convert_strl might be modified later
3117
3233
self ._convert_strl : list [Hashable ] = []
@@ -3127,6 +3243,7 @@ def __init__(
3127
3243
time_stamp = time_stamp ,
3128
3244
data_label = data_label ,
3129
3245
variable_labels = variable_labels ,
3246
+ value_labels = value_labels ,
3130
3247
compression = compression ,
3131
3248
storage_options = storage_options ,
3132
3249
)
@@ -3272,7 +3389,7 @@ def _write_value_label_names(self) -> None:
3272
3389
for i in range (self .nvar ):
3273
3390
# Use variable name when categorical
3274
3391
name = "" # default name
3275
- if self ._is_col_cat [i ]:
3392
+ if self ._has_value_labels [i ]:
3276
3393
name = self .varlist [i ]
3277
3394
name = self ._null_terminate_str (name )
3278
3395
encoded_name = _pad_bytes_new (name [:32 ].encode (self ._encoding ), vl_len + 1 )
@@ -3449,6 +3566,13 @@ class StataWriterUTF8(StataWriter117):
3449
3566
3450
3567
.. versionadded:: 1.1.0
3451
3568
3569
+ value_labels : dict of dicts
3570
+ Dictionary containing columns as keys and dictionaries of column value
3571
+ to labels as values. The combined length of all labels for a single
3572
+ variable must be 32,000 characters or smaller.
3573
+
3574
+ .. versionadded:: 1.4.0
3575
+
3452
3576
Returns
3453
3577
-------
3454
3578
StataWriterUTF8
@@ -3505,6 +3629,8 @@ def __init__(
3505
3629
version : int | None = None ,
3506
3630
compression : CompressionOptions = "infer" ,
3507
3631
storage_options : StorageOptions = None ,
3632
+ * ,
3633
+ value_labels : dict [Hashable , dict [float | int , str ]] | None = None ,
3508
3634
):
3509
3635
if version is None :
3510
3636
version = 118 if data .shape [1 ] <= 32767 else 119
@@ -3525,6 +3651,7 @@ def __init__(
3525
3651
time_stamp = time_stamp ,
3526
3652
data_label = data_label ,
3527
3653
variable_labels = variable_labels ,
3654
+ value_labels = value_labels ,
3528
3655
convert_strl = convert_strl ,
3529
3656
compression = compression ,
3530
3657
storage_options = storage_options ,
0 commit comments