7
7
import nose
8
8
import numpy as np
9
9
10
- from pandas import DataFrame , Series
10
+ from pandas import DataFrame , Series , MultiIndex
11
11
from pandas .compat import range , lrange , iteritems
12
12
#from pandas.core.datetools import format as date_format
13
13
@@ -266,7 +266,7 @@ def _roundtrip(self):
266
266
self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
267
267
result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
268
268
269
- result .set_index ('pandas_index ' , inplace = True )
269
+ result .set_index ('level_0 ' , inplace = True )
270
270
# result.index.astype(int)
271
271
272
272
result .index .name = None
@@ -391,7 +391,7 @@ def test_roundtrip(self):
391
391
392
392
# HACK!
393
393
result .index = self .test_frame1 .index
394
- result .set_index ('pandas_index ' , inplace = True )
394
+ result .set_index ('level_0 ' , inplace = True )
395
395
result .index .astype (int )
396
396
result .index .name = None
397
397
tm .assert_frame_equal (result , self .test_frame1 )
@@ -460,7 +460,9 @@ def test_date_and_index(self):
460
460
issubclass (df .IntDateCol .dtype .type , np .datetime64 ),
461
461
"IntDateCol loaded with incorrect type" )
462
462
463
+
463
464
class TestSQLApi (_TestSQLApi ):
465
+
464
466
"""Test the public API as it would be used directly
465
467
"""
466
468
flavor = 'sqlite'
@@ -474,10 +476,10 @@ def connect(self):
474
476
def test_to_sql_index_label (self ):
475
477
temp_frame = DataFrame ({'col1' : range (4 )})
476
478
477
- # no index name, defaults to 'pandas_index '
479
+ # no index name, defaults to 'index '
478
480
sql .to_sql (temp_frame , 'test_index_label' , self .conn )
479
481
frame = sql .read_table ('test_index_label' , self .conn )
480
- self .assertEqual (frame .columns [0 ], 'pandas_index ' )
482
+ self .assertEqual (frame .columns [0 ], 'index ' )
481
483
482
484
# specifying index_label
483
485
sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
@@ -487,11 +489,11 @@ def test_to_sql_index_label(self):
487
489
"Specified index_label not written to database" )
488
490
489
491
# using the index name
490
- temp_frame .index .name = 'index '
492
+ temp_frame .index .name = 'index_name '
491
493
sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
492
494
if_exists = 'replace' )
493
495
frame = sql .read_table ('test_index_label' , self .conn )
494
- self .assertEqual (frame .columns [0 ], 'index ' ,
496
+ self .assertEqual (frame .columns [0 ], 'index_name ' ,
495
497
"Index name not written to database" )
496
498
497
499
# has index name, but specifying index_label
@@ -501,8 +503,74 @@ def test_to_sql_index_label(self):
501
503
self .assertEqual (frame .columns [0 ], 'other_label' ,
502
504
"Specified index_label not written to database" )
503
505
506
+ def test_to_sql_index_label_multiindex (self ):
507
+ temp_frame = DataFrame ({'col1' : range (4 )},
508
+ index = MultiIndex .from_product ([('A0' , 'A1' ), ('B0' , 'B1' )]))
509
+
510
+ # no index name, defaults to 'level_0' and 'level_1'
511
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn )
512
+ frame = sql .read_table ('test_index_label' , self .conn )
513
+ self .assertEqual (frame .columns [0 ], 'level_0' )
514
+ self .assertEqual (frame .columns [1 ], 'level_1' )
515
+
516
+ # specifying index_label
517
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
518
+ if_exists = 'replace' , index_label = ['A' , 'B' ])
519
+ frame = sql .read_table ('test_index_label' , self .conn )
520
+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
521
+ "Specified index_labels not written to database" )
522
+
523
+ # using the index name
524
+ temp_frame .index .names = ['A' , 'B' ]
525
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
526
+ if_exists = 'replace' )
527
+ frame = sql .read_table ('test_index_label' , self .conn )
528
+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
529
+ "Index names not written to database" )
530
+
531
+ # has index name, but specifying index_label
532
+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
533
+ if_exists = 'replace' , index_label = ['C' , 'D' ])
534
+ frame = sql .read_table ('test_index_label' , self .conn )
535
+ self .assertEqual (frame .columns [:2 ].tolist (), ['C' , 'D' ],
536
+ "Specified index_labels not written to database" )
537
+
538
+ # wrong length of index_label
539
+ self .assertRaises (ValueError , sql .to_sql , temp_frame ,
540
+ 'test_index_label' , self .conn , if_exists = 'replace' ,
541
+ index_label = 'C' )
542
+
543
+ def test_read_table_columns (self ):
544
+ # test columns argument in read_table
545
+ sql .to_sql (self .test_frame1 , 'test_frame' , self .conn )
546
+
547
+ cols = ['A' , 'B' ]
548
+ result = sql .read_table ('test_frame' , self .conn , columns = cols )
549
+ self .assertEqual (result .columns .tolist (), cols ,
550
+ "Columns not correctly selected" )
551
+
552
+ def test_read_table_index_col (self ):
553
+ # test columns argument in read_table
554
+ sql .to_sql (self .test_frame1 , 'test_frame' , self .conn )
555
+
556
+ result = sql .read_table ('test_frame' , self .conn , index_col = "index" )
557
+ self .assertEqual (result .index .names , ["index" ],
558
+ "index_col not correctly set" )
559
+
560
+ result = sql .read_table ('test_frame' , self .conn , index_col = ["A" , "B" ])
561
+ self .assertEqual (result .index .names , ["A" , "B" ],
562
+ "index_col not correctly set" )
563
+
564
+ result = sql .read_table ('test_frame' , self .conn , index_col = ["A" , "B" ],
565
+ columns = ["C" , "D" ])
566
+ self .assertEqual (result .index .names , ["A" , "B" ],
567
+ "index_col not correctly set" )
568
+ self .assertEqual (result .columns .tolist (), ["C" , "D" ],
569
+ "columns not set correctly whith index_col" )
570
+
504
571
505
572
class TestSQLLegacyApi (_TestSQLApi ):
573
+
506
574
"""Test the public legacy API
507
575
"""
508
576
flavor = 'sqlite'
@@ -554,6 +622,23 @@ def test_sql_open_close(self):
554
622
555
623
tm .assert_frame_equal (self .test_frame2 , result )
556
624
625
+ def test_roundtrip (self ):
626
+ # this test otherwise fails, Legacy mode still uses 'pandas_index'
627
+ # as default index column label
628
+ sql .to_sql (self .test_frame1 , 'test_frame_roundtrip' ,
629
+ con = self .conn , flavor = 'sqlite' )
630
+ result = sql .read_sql (
631
+ 'SELECT * FROM test_frame_roundtrip' ,
632
+ con = self .conn ,
633
+ flavor = 'sqlite' )
634
+
635
+ # HACK!
636
+ result .index = self .test_frame1 .index
637
+ result .set_index ('pandas_index' , inplace = True )
638
+ result .index .astype (int )
639
+ result .index .name = None
640
+ tm .assert_frame_equal (result , self .test_frame1 )
641
+
557
642
558
643
class _TestSQLAlchemy (PandasSQLTest ):
559
644
"""
@@ -776,6 +861,16 @@ def setUp(self):
776
861
777
862
self ._load_test1_data ()
778
863
864
+ def _roundtrip (self ):
865
+ # overwrite parent function (level_0 -> pandas_index in legacy mode)
866
+ self .drop_table ('test_frame_roundtrip' )
867
+ self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
868
+ result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
869
+ result .set_index ('pandas_index' , inplace = True )
870
+ result .index .name = None
871
+
872
+ tm .assert_frame_equal (result , self .test_frame1 )
873
+
779
874
def test_invalid_flavor (self ):
780
875
self .assertRaises (
781
876
NotImplementedError , sql .PandasSQLLegacy , self .conn , 'oracle' )
0 commit comments