7
7
import nose
8
8
import numpy as np
9
9
10
- from pandas import DataFrame , Series
10
+ from pandas import DataFrame , Series , MultiIndex
11
11
from pandas .compat import range , lrange , iteritems
12
12
#from pandas.core.datetools import format as date_format
13
13
@@ -266,7 +266,7 @@ def _roundtrip(self):
266
266
self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
267
267
result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
268
268
269
- result .set_index ('pandas_index ' , inplace = True )
269
+ result .set_index ('level_0 ' , inplace = True )
270
270
# result.index.astype(int)
271
271
272
272
result .index .name = None
@@ -391,7 +391,7 @@ def test_roundtrip(self):
391
391
392
392
# HACK!
393
393
result .index = self .test_frame1 .index
394
- result .set_index ('pandas_index ' , inplace = True )
394
+ result .set_index ('level_0 ' , inplace = True )
395
395
result .index .astype (int )
396
396
result .index .name = None
397
397
tm .assert_frame_equal (result , self .test_frame1 )
@@ -476,10 +476,10 @@ def connect(self):
476
476
def test_to_sql_index_label (self ):
477
477
temp_frame = DataFrame ({'col1' : range (4 )})
478
478
479
- # no index name, defaults to 'pandas_index '
479
+ # no index name, defaults to 'index '
480
480
sql .to_sql (temp_frame , 'test_index_label' , self .conn )
481
481
frame = sql .read_table ('test_index_label' , self .conn )
482
- self .assertEqual (frame .columns [0 ], 'pandas_index ' )
482
+ self .assertEqual (frame .columns [0 ], 'index ' )
483
483
484
484
# specifying index_label
485
485
sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
@@ -489,11 +489,11 @@ def test_to_sql_index_label(self):
489
489
"Specified index_label not written to database" )
490
490
491
491
# using the index name
492
- temp_frame .index .name = 'index '
492
+ temp_frame .index .name = 'index_name '
493
493
sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
494
494
if_exists = 'replace' )
495
495
frame = sql .read_table ('test_index_label' , self .conn )
496
- self .assertEqual (frame .columns [0 ], 'index ' ,
496
+ self .assertEqual (frame .columns [0 ], 'index_name ' ,
497
497
"Index name not written to database" )
498
498
499
499
# has index name, but specifying index_label
@@ -503,6 +503,43 @@ def test_to_sql_index_label(self):
503
503
self .assertEqual (frame .columns [0 ], 'other_label' ,
504
504
"Specified index_label not written to database" )
505
505
506
+ def test_to_sql_index_label_multiindex (self ):
507
+ temp_frame = DataFrame ({'col1' : range (4 )},
508
+ index = MultiIndex .from_product ([('A0' , 'A1' ), ('B0' , 'B1' )]))
509
+
510
+ # no index name, defaults to 'level_0' and 'level_1'
511
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn )
512
+ frame = sql .read_table ('test_index_label' , self .conn )
513
+ self .assertEqual (frame .columns [0 ], 'level_0' )
514
+ self .assertEqual (frame .columns [1 ], 'level_1' )
515
+
516
+ # specifying index_label
517
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
518
+ if_exists = 'replace' , index_label = ['A' , 'B' ])
519
+ frame = sql .read_table ('test_index_label' , self .conn )
520
+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
521
+ "Specified index_labels not written to database" )
522
+
523
+ # using the index name
524
+ temp_frame .index .names = ['A' , 'B' ]
525
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
526
+ if_exists = 'replace' )
527
+ frame = sql .read_table ('test_index_label' , self .conn )
528
+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
529
+ "Index names not written to database" )
530
+
531
+ # has index name, but specifying index_label
532
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
533
+ if_exists = 'replace' , index_label = ['C' , 'D' ])
534
+ frame = sql .read_table ('test_index_label' , self .conn )
535
+ self .assertEqual (frame .columns [:2 ].tolist (), ['C' , 'D' ],
536
+ "Specified index_labels not written to database" )
537
+
538
+ # wrong length of index_label
539
+ self .assertRaises (ValueError , sql .to_sql , temp_frame ,
540
+ 'test_index_label' , self .conn , if_exists = 'replace' ,
541
+ index_label = 'C' )
542
+
506
543
def test_read_table_columns (self ):
507
544
# test columns argument in read_table
508
545
sql .to_sql (self .test_frame1 , 'test_frame' , self .conn )
@@ -566,6 +603,23 @@ def test_sql_open_close(self):
566
603
567
604
tm .assert_frame_equal (self .test_frame2 , result )
568
605
606
+ def test_roundtrip (self ):
607
+ # this test otherwise fails, Legacy mode still uses 'pandas_index'
608
+ # as default index column label
609
+ sql .to_sql (self .test_frame1 , 'test_frame_roundtrip' ,
610
+ con = self .conn , flavor = 'sqlite' )
611
+ result = sql .read_sql (
612
+ 'SELECT * FROM test_frame_roundtrip' ,
613
+ con = self .conn ,
614
+ flavor = 'sqlite' )
615
+
616
+ # HACK!
617
+ result .index = self .test_frame1 .index
618
+ result .set_index ('pandas_index' , inplace = True )
619
+ result .index .astype (int )
620
+ result .index .name = None
621
+ tm .assert_frame_equal (result , self .test_frame1 )
622
+
569
623
570
624
class _TestSQLAlchemy (PandasSQLTest ):
571
625
"""
@@ -788,6 +842,16 @@ def setUp(self):
788
842
789
843
self ._load_test1_data ()
790
844
845
+ def _roundtrip (self ):
846
+ # overwrite parent function (level_0 -> pandas_index in legacy mode)
847
+ self .drop_table ('test_frame_roundtrip' )
848
+ self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
849
+ result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
850
+ result .set_index ('pandas_index' , inplace = True )
851
+ result .index .name = None
852
+
853
+ tm .assert_frame_equal (result , self .test_frame1 )
854
+
791
855
def test_invalid_flavor (self ):
792
856
self .assertRaises (
793
857
NotImplementedError , sql .PandasSQLLegacy , self .conn , 'oracle' )
0 commit comments