15
15
import struct
16
16
from dateutil .relativedelta import relativedelta
17
17
from pandas .core .base import StringMixin
18
+ from pandas .core .categorical import Categorical
18
19
from pandas .core .frame import DataFrame
19
20
from pandas .core .series import Series
20
- from pandas .core .categorical import Categorical
21
21
import datetime
22
22
from pandas import compat , to_timedelta , to_datetime , isnull , DatetimeIndex
23
23
from pandas .compat import lrange , lmap , lzip , text_type , string_types , range , \
24
- zip
24
+ zip , BytesIO
25
25
import pandas .core .common as com
26
26
from pandas .io .common import get_filepath_or_buffer
27
27
from pandas .lib import max_len_string_array , infer_dtype
@@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning):
336
336
conversion range. This may result in a loss of precision in the saved data.
337
337
"""
338
338
339
+ class ValueLabelTypeMismatch (Warning ):
340
+ pass
341
+
342
+ value_label_mismatch_doc = """
343
+ Stata value labels (pandas categories) must be strings. Column {0} contains
344
+ non-string labels which will be converted to strings. Please check that the
345
+ Stata data file created has not lost information due to duplicate labels.
346
+ """
347
+
339
348
340
349
class InvalidColumnName (Warning ):
341
350
pass
@@ -425,6 +434,131 @@ def _cast_to_stata_types(data):
425
434
return data
426
435
427
436
437
+ class StataValueLabel (object ):
438
+ """
439
+ Parse a categorical column and prepare formatted output
440
+
441
+ Parameters
442
+ -----------
443
+ value : int8, int16, int32, float32 or float64
444
+ The Stata missing value code
445
+
446
+ Attributes
447
+ ----------
448
+ string : string
449
+ String representation of the Stata missing value
450
+ value : int8, int16, int32, float32 or float64
451
+ The original encoded missing value
452
+
453
+ Methods
454
+ -------
455
+ generate_value_label
456
+
457
+ """
458
+
459
+ def __init__ (self , catarray ):
460
+
461
+ self .labname = catarray .name
462
+
463
+ categories = catarray .cat .categories
464
+ self .value_labels = list (zip (np .arange (len (categories )), categories ))
465
+ self .value_labels .sort (key = lambda x : x [0 ])
466
+ self .text_len = np .int32 (0 )
467
+ self .off = []
468
+ self .val = []
469
+ self .txt = []
470
+ self .n = 0
471
+
472
+ # Compute lengths and setup lists of offsets and labels
473
+ for vl in self .value_labels :
474
+ category = vl [1 ]
475
+ if not isinstance (category , string_types ):
476
+ category = str (category )
477
+ import warnings
478
+ warnings .warn (value_label_mismatch_doc .format (catarray .name ),
479
+ ValueLabelTypeMismatch )
480
+
481
+ self .off .append (self .text_len )
482
+ self .text_len += len (category ) + 1 # +1 for the padding
483
+ self .val .append (vl [0 ])
484
+ self .txt .append (category )
485
+ self .n += 1
486
+
487
+ if self .text_len > 32000 :
488
+ raise ValueError ('Stata value labels for a single variable must '
489
+ 'have a combined length less than 32,000 '
490
+ 'characters.' )
491
+
492
+ # Ensure int32
493
+ self .off = np .array (self .off , dtype = np .int32 )
494
+ self .val = np .array (self .val , dtype = np .int32 )
495
+
496
+ # Total length
497
+ self .len = 4 + 4 + 4 * self .n + 4 * self .n + self .text_len
498
+
499
+ def _encode (self , s ):
500
+ """
501
+ Python 3 compatability shim
502
+ """
503
+ if compat .PY3 :
504
+ return s .encode (self ._encoding )
505
+ else :
506
+ return s
507
+
508
+ def generate_value_label (self , byteorder , encoding ):
509
+ """
510
+ Parameters
511
+ ----------
512
+ byteorder : str
513
+ Byte order of the output
514
+ encoding : str
515
+ File encoding
516
+
517
+ Returns
518
+ -------
519
+ value_label : bytes
520
+ Bytes containing the formatted value label
521
+ """
522
+
523
+ self ._encoding = encoding
524
+ bio = BytesIO ()
525
+ null_string = '\x00 '
526
+ null_byte = b'\x00 '
527
+
528
+ # len
529
+ bio .write (struct .pack (byteorder + 'i' , self .len ))
530
+
531
+ # labname
532
+ labname = self ._encode (_pad_bytes (self .labname [:32 ], 33 ))
533
+ bio .write (labname )
534
+
535
+ # padding - 3 bytes
536
+ for i in range (3 ):
537
+ bio .write (struct .pack ('c' , null_byte ))
538
+
539
+ # value_label_table
540
+ # n - int32
541
+ bio .write (struct .pack (byteorder + 'i' , self .n ))
542
+
543
+ # textlen - int32
544
+ bio .write (struct .pack (byteorder + 'i' , self .text_len ))
545
+
546
+ # off - int32 array (n elements)
547
+ for offset in self .off :
548
+ bio .write (struct .pack (byteorder + 'i' , offset ))
549
+
550
+ # val - int32 array (n elements)
551
+ for value in self .val :
552
+ bio .write (struct .pack (byteorder + 'i' , value ))
553
+
554
+ # txt - Text labels, null terminated
555
+ for text in self .txt :
556
+ bio .write (self ._encode (text + null_string ))
557
+
558
+ bio .seek (0 )
559
+ return bio .read ()
560
+
561
+
428
562
class StataMissingValue (StringMixin ):
429
563
"""
430
564
An observation's missing value.
@@ -477,25 +611,31 @@ class StataMissingValue(StringMixin):
477
611
for i in range (1 , 27 ):
478
612
MISSING_VALUES [i + b ] = '.' + chr (96 + i )
479
613
480
- base = b'\x00 \x00 \x00 \x7f '
614
+ float32_base = b'\x00 \x00 \x00 \x7f '
481
615
increment = struct .unpack ('<i' , b'\x00 \x08 \x00 \x00 ' )[0 ]
482
616
for i in range (27 ):
483
- value = struct .unpack ('<f' , base )[0 ]
617
+ value = struct .unpack ('<f' , float32_base )[0 ]
484
618
MISSING_VALUES [value ] = '.'
485
619
if i > 0 :
486
620
MISSING_VALUES [value ] += chr (96 + i )
487
621
int_value = struct .unpack ('<i' , struct .pack ('<f' , value ))[0 ] + increment
488
- base = struct .pack ('<i' , int_value )
622
+ float32_base = struct .pack ('<i' , int_value )
489
623
490
- base = b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f '
624
+ float64_base = b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f '
491
625
increment = struct .unpack ('q' , b'\x00 \x00 \x00 \x00 \x00 \x01 \x00 \x00 ' )[0 ]
492
626
for i in range (27 ):
493
- value = struct .unpack ('<d' , base )[0 ]
627
+ value = struct .unpack ('<d' , float64_base )[0 ]
494
628
MISSING_VALUES [value ] = '.'
495
629
if i > 0 :
496
630
MISSING_VALUES [value ] += chr (96 + i )
497
631
int_value = struct .unpack ('q' , struct .pack ('<d' , value ))[0 ] + increment
498
- base = struct .pack ('q' , int_value )
632
+ float64_base = struct .pack ('q' , int_value )
633
+
634
+ BASE_MISSING_VALUES = {'int8' : 101 ,
635
+ 'int16' : 32741 ,
636
+ 'int32' : 2147483621 ,
637
+ 'float32' : struct .unpack ('<f' , float32_base )[0 ],
638
+ 'float64' : struct .unpack ('<d' , float64_base )[0 ]}
499
639
500
640
def __init__ (self , value ):
501
641
self ._value = value
@@ -518,6 +658,22 @@ def __eq__(self, other):
518
658
return (isinstance (other , self .__class__ )
519
659
and self .string == other .string and self .value == other .value )
520
660
661
+ @classmethod
662
+ def get_base_missing_value (cls , dtype ):
663
+ if dtype == np .int8 :
664
+ value = cls .BASE_MISSING_VALUES ['int8' ]
665
+ elif dtype == np .int16 :
666
+ value = cls .BASE_MISSING_VALUES ['int16' ]
667
+ elif dtype == np .int32 :
668
+ value = cls .BASE_MISSING_VALUES ['int32' ]
669
+ elif dtype == np .float32 :
670
+ value = cls .BASE_MISSING_VALUES ['float32' ]
671
+ elif dtype == np .float64 :
672
+ value = cls .BASE_MISSING_VALUES ['float64' ]
673
+ else :
674
+ raise ValueError ('Unsupported dtype' )
675
+ return value
676
+
521
677
522
678
class StataParser (object ):
523
679
_default_encoding = 'cp1252'
@@ -1111,10 +1267,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
1111
1267
umissing , umissing_loc = np .unique (series [missing ],
1112
1268
return_inverse = True )
1113
1269
replacement = Series (series , dtype = np .object )
1114
- for i , um in enumerate (umissing ):
1270
+ for j , um in enumerate (umissing ):
1115
1271
missing_value = StataMissingValue (um )
1116
1272
1117
- loc = missing_loc [umissing_loc == i ]
1273
+ loc = missing_loc [umissing_loc == j ]
1118
1274
replacement .iloc [loc ] = missing_value
1119
1275
else : # All replacements are identical
1120
1276
dtype = series .dtype
@@ -1390,6 +1546,45 @@ def _write(self, to_write):
1390
1546
else :
1391
1547
self ._file .write (to_write )
1392
1548
1549
+ def _prepare_categoricals (self , data ):
1550
+ """Check for categorigal columns, retain categorical information for
1551
+ Stata file and convert categorical data to int"""
1552
+
1553
+ is_cat = [com .is_categorical_dtype (data [col ]) for col in data ]
1554
+ self ._is_col_cat = is_cat
1555
+ self ._value_labels = []
1556
+ if not any (is_cat ):
1557
+ return data
1558
+
1559
+ get_base_missing_value = StataMissingValue .get_base_missing_value
1560
+ index = data .index
1561
+ data_formatted = []
1562
+ for col , col_is_cat in zip (data , is_cat ):
1563
+ if col_is_cat :
1564
+ self ._value_labels .append (StataValueLabel (data [col ]))
1565
+ dtype = data [col ].cat .codes .dtype
1566
+ if dtype == np .int64 :
1567
+ raise ValueError ('It is not possible to export int64-based '
1568
+ 'categorical data to Stata.' )
1569
+ values = data [col ].cat .codes .values .copy ()
1570
+
1571
+ # Upcast if needed so that correct missing values can be set
1572
+ if values .max () >= get_base_missing_value (dtype ):
1573
+ if dtype == np .int8 :
1574
+ dtype = np .int16
1575
+ elif dtype == np .int16 :
1576
+ dtype = np .int32
1577
+ else :
1578
+ dtype = np .float64
1579
+ values = np .array (values , dtype = dtype )
1580
+
1581
+ # Replace missing values with Stata missing value for type
1582
+ values [values == - 1 ] = get_base_missing_value (dtype )
1583
+ data_formatted .append ((col , values , index ))
1584
+
1585
+ else :
1586
+ data_formatted .append ((col , data [col ]))
1587
+ return DataFrame .from_items (data_formatted )
1393
1588
1394
1589
def _replace_nans (self , data ):
1395
1590
# return data
@@ -1480,27 +1675,26 @@ def _check_column_names(self, data):
1480
1675
def _prepare_pandas (self , data ):
1481
1676
#NOTE: we might need a different API / class for pandas objects so
1482
1677
# we can set different semantics - handle this with a PR to pandas.io
1483
- class DataFrameRowIter (object ):
1484
- def __init__ (self , data ):
1485
- self .data = data
1486
-
1487
- def __iter__ (self ):
1488
- for row in data .itertuples ():
1489
- # First element is index, so remove
1490
- yield row [1 :]
1491
1678
1492
1679
if self ._write_index :
1493
1680
data = data .reset_index ()
1494
- # Check columns for compatibility with stata
1495
- data = _cast_to_stata_types (data )
1681
+
1496
1682
# Ensure column names are strings
1497
1683
data = self ._check_column_names (data )
1684
+
1685
+ # Check columns for compatibility with stata, upcast if necessary
1686
+ data = _cast_to_stata_types (data )
1687
+
1498
1688
# Replace NaNs with Stata missing values
1499
1689
data = self ._replace_nans (data )
1500
- self .datarows = DataFrameRowIter (data )
1690
+
1691
+ # Convert categoricals to int data, and strip labels
1692
+ data = self ._prepare_categoricals (data )
1693
+
1501
1694
self .nobs , self .nvar = data .shape
1502
1695
self .data = data
1503
1696
self .varlist = data .columns .tolist ()
1697
+
1504
1698
dtypes = data .dtypes
1505
1699
if self ._convert_dates is not None :
1506
1700
self ._convert_dates = _maybe_convert_to_int_keys (
@@ -1515,6 +1709,7 @@ def __iter__(self):
1515
1709
self .fmtlist = []
1516
1710
for col , dtype in dtypes .iteritems ():
1517
1711
self .fmtlist .append (_dtype_to_default_stata_fmt (dtype , data [col ]))
1712
+
1518
1713
# set the given format for the datetime cols
1519
1714
if self ._convert_dates is not None :
1520
1715
for key in self ._convert_dates :
@@ -1529,8 +1724,14 @@ def write_file(self):
1529
1724
self ._write (_pad_bytes ("" , 5 ))
1530
1725
self ._prepare_data ()
1531
1726
self ._write_data ()
1727
+ self ._write_value_labels ()
1532
1728
self ._file .close ()
1533
1729
1730
+ def _write_value_labels (self ):
1731
+ for vl in self ._value_labels :
1732
+ self ._file .write (vl .generate_value_label (self ._byteorder ,
1733
+ self ._encoding ))
1734
+
1534
1735
def _write_header (self , data_label = None , time_stamp = None ):
1535
1736
byteorder = self ._byteorder
1536
1737
# ds_format - just use 114
@@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
1585
1786
self ._write (_pad_bytes (fmt , 49 ))
1586
1787
1587
1788
# lbllist, 33*nvar, char array
1588
- #NOTE: this is where you could get fancy with pandas categorical type
1589
1789
for i in range (nvar ):
1590
- self ._write (_pad_bytes ("" , 33 ))
1790
+ # Use variable name when categorical
1791
+ if self ._is_col_cat [i ]:
1792
+ name = self .varlist [i ]
1793
+ name = self ._null_terminate (name , True )
1794
+ name = _pad_bytes (name [:32 ], 33 )
1795
+ self ._write (name )
1796
+ else : # Default is empty label
1797
+ self ._write (_pad_bytes ("" , 33 ))
1591
1798
1592
1799
def _write_variable_labels (self , labels = None ):
1593
1800
nvar = self .nvar
@@ -1624,9 +1831,6 @@ def _prepare_data(self):
1624
1831
data_cols .append (data [col ].values )
1625
1832
dtype = np .dtype (dtype )
1626
1833
1627
- # 3. Convert to record array
1628
-
1629
- # data.to_records(index=False, convert_datetime64=False)
1630
1834
if has_strings :
1631
1835
self .data = np .fromiter (zip (* data_cols ), dtype = dtype )
1632
1836
else :
0 commit comments