Skip to content

Commit 0a26065

Browse files
ENH/TST SQL: add multi-index support to to_sql
1 parent c7928f6 commit 0a26065

File tree

2 files changed

+119
-26
lines changed

2 files changed

+119
-26
lines changed

pandas/io/sql.py

+48 -19
Original file line number | Diff line number | Diff line change
@@ -439,19 +439,25 @@ def maybe_asscalar(self, i):
439439
def insert(self):
440440
ins = self.insert_statement()
441441
data_list = []
442-
# to avoid if check for every row
443-
keys = self.frame.columns
442+
444443
if self.index is not None:
445-
for t in self.frame.itertuples():
446-
data = dict((k, self.maybe_asscalar(v))
447-
for k, v in zip(keys, t[1:]))
448-
data[self.index] = self.maybe_asscalar(t[0])
449-
data_list.append(data)
444+
temp = self.frame.copy()
445+
temp.index.names = self.index
446+
try:
447+
temp.reset_index(inplace=True)
448+
except ValueError as err:
449+
raise ValueError(
450+
"duplicate name in index/columns: {0}".format(err))
450451
else:
451-
for t in self.frame.itertuples():
452-
data = dict((k, self.maybe_asscalar(v))
453-
for k, v in zip(keys, t[1:]))
454-
data_list.append(data)
452+
temp = self.frame
453+
454+
keys = temp.columns
455+
456+
for t in temp.itertuples():
457+
data = dict((k, self.maybe_asscalar(v))
458+
for k, v in zip(keys, t[1:]))
459+
data_list.append(data)
460+
455461
self.pd_sql.execute(ins, data_list)
456462

457463
def read(self, coerce_float=True, parse_dates=None, columns=None):
@@ -486,12 +492,24 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
486492

487493
def _index_name(self, index, index_label):
488494
if index is True:
495+
nlevels = self.frame.index.nlevels
496+
# if index_label is specified, set this as index name(s)
489497
if index_label is not None:
490-
return _safe_col_name(index_label)
491-
elif self.frame.index.name is not None:
492-
return _safe_col_name(self.frame.index.name)
498+
if not isinstance(index_label, list):
499+
index_label = [index_label]
500+
if len(index_label) != nlevels:
501+
raise ValueError(
502+
"Length of 'index_label' should match number of "
503+
"levels, which is {0}".format(nlevels))
504+
else:
505+
return index_label
506+
# return the used column labels for the index columns
507+
if nlevels == 1 and 'index' not in self.frame.columns and self.frame.index.name is None:
508+
return ['index']
493509
else:
494-
return self.prefix + '_index'
510+
return [l if l is not None else "level_{0}".format(i)
511+
for i, l in enumerate(self.frame.index.names)]
512+
495513
elif isinstance(index, string_types):
496514
return index
497515
else:
@@ -507,10 +525,10 @@ def _create_table_statement(self):
507525
for name, typ in zip(safe_columns, column_types)]
508526

509527
if self.index is not None:
510-
columns.insert(0, Column(self.index,
511-
self._sqlalchemy_type(
512-
self.frame.index),
513-
index=True))
528+
for i, idx_label in enumerate(self.index[::-1]):
529+
idx_type = self._sqlalchemy_type(
530+
self.frame.index.get_level_values(i))
531+
columns.insert(0, Column(idx_label, idx_type, index=True))
514532

515533
return Table(self.name, self.pd_sql.meta, *columns)
516534

@@ -788,6 +806,17 @@ def insert(self):
788806
cur.close()
789807
self.pd_sql.con.commit()
790808

809+
def _index_name(self, index, index_label):
810+
if index is True:
811+
if self.frame.index.name is not None:
812+
return _safe_col_name(self.frame.index.name)
813+
else:
814+
return 'pandas_index'
815+
elif isinstance(index, string_types):
816+
return index
817+
else:
818+
return None
819+
791820
def _create_table_statement(self):
792821
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
793822

pandas/io/tests/test_sql.py

+71 -7
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import nose
88
import numpy as np
99

10-
from pandas import DataFrame, Series
10+
from pandas import DataFrame, Series, MultiIndex
1111
from pandas.compat import range, lrange, iteritems
1212
#from pandas.core.datetools import format as date_format
1313

@@ -266,7 +266,7 @@ def _roundtrip(self):
266266
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
267267
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
268268

269-
result.set_index('pandas_index', inplace=True)
269+
result.set_index('level_0', inplace=True)
270270
# result.index.astype(int)
271271

272272
result.index.name = None
@@ -391,7 +391,7 @@ def test_roundtrip(self):
391391

392392
# HACK!
393393
result.index = self.test_frame1.index
394-
result.set_index('pandas_index', inplace=True)
394+
result.set_index('level_0', inplace=True)
395395
result.index.astype(int)
396396
result.index.name = None
397397
tm.assert_frame_equal(result, self.test_frame1)
@@ -476,10 +476,10 @@ def connect(self):
476476
def test_to_sql_index_label(self):
477477
temp_frame = DataFrame({'col1': range(4)})
478478

479-
# no index name, defaults to 'pandas_index'
479+
# no index name, defaults to 'index'
480480
sql.to_sql(temp_frame, 'test_index_label', self.conn)
481481
frame = sql.read_table('test_index_label', self.conn)
482-
self.assertEqual(frame.columns[0], 'pandas_index')
482+
self.assertEqual(frame.columns[0], 'index')
483483

484484
# specifying index_label
485485
sql.to_sql(temp_frame, 'test_index_label', self.conn,
@@ -489,11 +489,11 @@ def test_to_sql_index_label(self):
489489
"Specified index_label not written to database")
490490

491491
# using the index name
492-
temp_frame.index.name = 'index'
492+
temp_frame.index.name = 'index_name'
493493
sql.to_sql(temp_frame, 'test_index_label', self.conn,
494494
if_exists='replace')
495495
frame = sql.read_table('test_index_label', self.conn)
496-
self.assertEqual(frame.columns[0], 'index',
496+
self.assertEqual(frame.columns[0], 'index_name',
497497
"Index name not written to database")
498498

499499
# has index name, but specifying index_label
@@ -503,6 +503,43 @@ def test_to_sql_index_label(self):
503503
self.assertEqual(frame.columns[0], 'other_label',
504504
"Specified index_label not written to database")
505505

506+
def test_to_sql_index_label_multiindex(self):
507+
temp_frame = DataFrame({'col1': range(4)},
508+
index=MultiIndex.from_product([('A0', 'A1'), ('B0', 'B1')]))
509+
510+
# no index name, defaults to 'level_0' and 'level_1'
511+
sql.to_sql(temp_frame, 'test_index_label', self.conn)
512+
frame = sql.read_table('test_index_label', self.conn)
513+
self.assertEqual(frame.columns[0], 'level_0')
514+
self.assertEqual(frame.columns[1], 'level_1')
515+
516+
# specifying index_label
517+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
518+
if_exists='replace', index_label=['A', 'B'])
519+
frame = sql.read_table('test_index_label', self.conn)
520+
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
521+
"Specified index_labels not written to database")
522+
523+
# using the index name
524+
temp_frame.index.names = ['A', 'B']
525+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
526+
if_exists='replace')
527+
frame = sql.read_table('test_index_label', self.conn)
528+
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
529+
"Index names not written to database")
530+
531+
# has index name, but specifying index_label
532+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
533+
if_exists='replace', index_label=['C', 'D'])
534+
frame = sql.read_table('test_index_label', self.conn)
535+
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
536+
"Specified index_labels not written to database")
537+
538+
# wrong length of index_label
539+
self.assertRaises(ValueError, sql.to_sql, temp_frame,
540+
'test_index_label', self.conn, if_exists='replace',
541+
index_label='C')
542+
506543
def test_read_table_columns(self):
507544
# test columns argument in read_table
508545
sql.to_sql(self.test_frame1, 'test_frame', self.conn)
@@ -566,6 +603,23 @@ def test_sql_open_close(self):
566603

567604
tm.assert_frame_equal(self.test_frame2, result)
568605

606+
def test_roundtrip(self):
607+
# this test otherwise fails, Legacy mode still uses 'pandas_index'
608+
# as default index column label
609+
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
610+
con=self.conn, flavor='sqlite')
611+
result = sql.read_sql(
612+
'SELECT * FROM test_frame_roundtrip',
613+
con=self.conn,
614+
flavor='sqlite')
615+
616+
# HACK!
617+
result.index = self.test_frame1.index
618+
result.set_index('pandas_index', inplace=True)
619+
result.index.astype(int)
620+
result.index.name = None
621+
tm.assert_frame_equal(result, self.test_frame1)
622+
569623

570624
class _TestSQLAlchemy(PandasSQLTest):
571625
"""
@@ -788,6 +842,16 @@ def setUp(self):
788842

789843
self._load_test1_data()
790844

845+
def _roundtrip(self):
846+
# overwrite parent function (level_0 -> pandas_index in legacy mode)
847+
self.drop_table('test_frame_roundtrip')
848+
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
849+
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
850+
result.set_index('pandas_index', inplace=True)
851+
result.index.name = None
852+
853+
tm.assert_frame_equal(result, self.test_frame1)
854+
791855
def test_invalid_flavor(self):
792856
self.assertRaises(
793857
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')

0 commit comments

Comments
 (0)