Skip to content

Commit ad1f47d

Browse files
Merge pull request #6735 from jorisvandenbossche/sql-multiindex
ENH: SQL multiindex support
2 parents 8e36ff4 + 18bd0d6 commit ad1f47d

File tree

2 files changed

+160
-36
lines changed

2 files changed

+160
-36
lines changed

pandas/io/sql.py

+58-29
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
310310
Legacy mode not supported
311311
meta : SQLAlchemy meta, optional
312312
If omitted MetaData is reflected from engine
313-
index_col : string, optional
314-
Column to set as index
313+
index_col : string or sequence of strings, optional
314+
Column(s) to set as index.
315315
coerce_float : boolean, default True
316316
Attempt to convert values to non-string, non-numeric objects (like
317317
decimal.Decimal) to floating point. Can result in loss of Precision.
@@ -324,7 +324,7 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
324324
to the keyword arguments of :func:`pandas.to_datetime`
325325
Especially useful with databases without native Datetime support,
326326
such as SQLite
327-
columns : list
327+
columns : list, optional
328328
List of column names to select from sql table
329329
330330
Returns
@@ -340,7 +340,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
340340
table = pandas_sql.read_table(table_name,
341341
index_col=index_col,
342342
coerce_float=coerce_float,
343-
parse_dates=parse_dates)
343+
parse_dates=parse_dates,
344+
columns=columns)
344345

345346
if table is not None:
346347
return table
@@ -438,19 +439,25 @@ def maybe_asscalar(self, i):
438439
def insert(self):
439440
ins = self.insert_statement()
440441
data_list = []
441-
# to avoid if check for every row
442-
keys = self.frame.columns
442+
443443
if self.index is not None:
444-
for t in self.frame.itertuples():
445-
data = dict((k, self.maybe_asscalar(v))
446-
for k, v in zip(keys, t[1:]))
447-
data[self.index] = self.maybe_asscalar(t[0])
448-
data_list.append(data)
444+
temp = self.frame.copy()
445+
temp.index.names = self.index
446+
try:
447+
temp.reset_index(inplace=True)
448+
except ValueError as err:
449+
raise ValueError(
450+
"duplicate name in index/columns: {0}".format(err))
449451
else:
450-
for t in self.frame.itertuples():
451-
data = dict((k, self.maybe_asscalar(v))
452-
for k, v in zip(keys, t[1:]))
453-
data_list.append(data)
452+
temp = self.frame
453+
454+
keys = temp.columns
455+
456+
for t in temp.itertuples():
457+
data = dict((k, self.maybe_asscalar(v))
458+
for k, v in zip(keys, t[1:]))
459+
data_list.append(data)
460+
454461
self.pd_sql.execute(ins, data_list)
455462

456463
def read(self, coerce_float=True, parse_dates=None, columns=None):
@@ -459,7 +466,7 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
459466
from sqlalchemy import select
460467
cols = [self.table.c[n] for n in columns]
461468
if self.index is not None:
462-
cols.insert(0, self.table.c[self.index])
469+
[cols.insert(0, self.table.c[idx]) for idx in self.index[::-1]]
463470
sql_select = select(cols)
464471
else:
465472
sql_select = self.table.select()
@@ -476,22 +483,33 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
476483
if self.index is not None:
477484
self.frame.set_index(self.index, inplace=True)
478485

479-
# Assume if the index in prefix_index format, we gave it a name
480-
# and should return it nameless
481-
if self.index == self.prefix + '_index':
482-
self.frame.index.name = None
483-
484486
return self.frame
485487

486488
def _index_name(self, index, index_label):
489+
# for writing: index=True to include index in sql table
487490
if index is True:
491+
nlevels = self.frame.index.nlevels
492+
# if index_label is specified, set this as index name(s)
488493
if index_label is not None:
489-
return _safe_col_name(index_label)
490-
elif self.frame.index.name is not None:
491-
return _safe_col_name(self.frame.index.name)
494+
if not isinstance(index_label, list):
495+
index_label = [index_label]
496+
if len(index_label) != nlevels:
497+
raise ValueError(
498+
"Length of 'index_label' should match number of "
499+
"levels, which is {0}".format(nlevels))
500+
else:
501+
return index_label
502+
# return the used column labels for the index columns
503+
if nlevels == 1 and 'index' not in self.frame.columns and self.frame.index.name is None:
504+
return ['index']
492505
else:
493-
return self.prefix + '_index'
506+
return [l if l is not None else "level_{0}".format(i)
507+
for i, l in enumerate(self.frame.index.names)]
508+
509+
# for reading: index=(list of) string to specify column to set as index
494510
elif isinstance(index, string_types):
511+
return [index]
512+
elif isinstance(index, list):
495513
return index
496514
else:
497515
return None
@@ -506,10 +524,10 @@ def _create_table_statement(self):
506524
for name, typ in zip(safe_columns, column_types)]
507525

508526
if self.index is not None:
509-
columns.insert(0, Column(self.index,
510-
self._sqlalchemy_type(
511-
self.frame.index),
512-
index=True))
527+
for i, idx_label in enumerate(self.index[::-1]):
528+
idx_type = self._sqlalchemy_type(
529+
self.frame.index.get_level_values(i))
530+
columns.insert(0, Column(idx_label, idx_type, index=True))
513531

514532
return Table(self.name, self.pd_sql.meta, *columns)
515533

@@ -787,6 +805,17 @@ def insert(self):
787805
cur.close()
788806
self.pd_sql.con.commit()
789807

808+
def _index_name(self, index, index_label):
809+
if index is True:
810+
if self.frame.index.name is not None:
811+
return _safe_col_name(self.frame.index.name)
812+
else:
813+
return 'pandas_index'
814+
elif isinstance(index, string_types):
815+
return index
816+
else:
817+
return None
818+
790819
def _create_table_statement(self):
791820
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
792821

pandas/io/tests/test_sql.py

+102-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import nose
88
import numpy as np
99

10-
from pandas import DataFrame, Series
10+
from pandas import DataFrame, Series, MultiIndex
1111
from pandas.compat import range, lrange, iteritems
1212
#from pandas.core.datetools import format as date_format
1313

@@ -266,7 +266,7 @@ def _roundtrip(self):
266266
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
267267
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
268268

269-
result.set_index('pandas_index', inplace=True)
269+
result.set_index('level_0', inplace=True)
270270
# result.index.astype(int)
271271

272272
result.index.name = None
@@ -391,7 +391,7 @@ def test_roundtrip(self):
391391

392392
# HACK!
393393
result.index = self.test_frame1.index
394-
result.set_index('pandas_index', inplace=True)
394+
result.set_index('level_0', inplace=True)
395395
result.index.astype(int)
396396
result.index.name = None
397397
tm.assert_frame_equal(result, self.test_frame1)
@@ -460,7 +460,9 @@ def test_date_and_index(self):
460460
issubclass(df.IntDateCol.dtype.type, np.datetime64),
461461
"IntDateCol loaded with incorrect type")
462462

463+
463464
class TestSQLApi(_TestSQLApi):
465+
464466
"""Test the public API as it would be used directly
465467
"""
466468
flavor = 'sqlite'
@@ -474,10 +476,10 @@ def connect(self):
474476
def test_to_sql_index_label(self):
475477
temp_frame = DataFrame({'col1': range(4)})
476478

477-
# no index name, defaults to 'pandas_index'
479+
# no index name, defaults to 'index'
478480
sql.to_sql(temp_frame, 'test_index_label', self.conn)
479481
frame = sql.read_table('test_index_label', self.conn)
480-
self.assertEqual(frame.columns[0], 'pandas_index')
482+
self.assertEqual(frame.columns[0], 'index')
481483

482484
# specifying index_label
483485
sql.to_sql(temp_frame, 'test_index_label', self.conn,
@@ -487,11 +489,11 @@ def test_to_sql_index_label(self):
487489
"Specified index_label not written to database")
488490

489491
# using the index name
490-
temp_frame.index.name = 'index'
492+
temp_frame.index.name = 'index_name'
491493
sql.to_sql(temp_frame, 'test_index_label', self.conn,
492494
if_exists='replace')
493495
frame = sql.read_table('test_index_label', self.conn)
494-
self.assertEqual(frame.columns[0], 'index',
496+
self.assertEqual(frame.columns[0], 'index_name',
495497
"Index name not written to database")
496498

497499
# has index name, but specifying index_label
@@ -501,8 +503,74 @@ def test_to_sql_index_label(self):
501503
self.assertEqual(frame.columns[0], 'other_label',
502504
"Specified index_label not written to database")
503505

506+
def test_to_sql_index_label_multiindex(self):
507+
temp_frame = DataFrame({'col1': range(4)},
508+
index=MultiIndex.from_product([('A0', 'A1'), ('B0', 'B1')]))
509+
510+
# no index name, defaults to 'level_0' and 'level_1'
511+
sql.to_sql(temp_frame, 'test_index_label', self.conn)
512+
frame = sql.read_table('test_index_label', self.conn)
513+
self.assertEqual(frame.columns[0], 'level_0')
514+
self.assertEqual(frame.columns[1], 'level_1')
515+
516+
# specifying index_label
517+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
518+
if_exists='replace', index_label=['A', 'B'])
519+
frame = sql.read_table('test_index_label', self.conn)
520+
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
521+
"Specified index_labels not written to database")
522+
523+
# using the index name
524+
temp_frame.index.names = ['A', 'B']
525+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
526+
if_exists='replace')
527+
frame = sql.read_table('test_index_label', self.conn)
528+
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
529+
"Index names not written to database")
530+
531+
# has index name, but specifying index_label
532+
sql.to_sql(temp_frame, 'test_index_label', self.conn,
533+
if_exists='replace', index_label=['C', 'D'])
534+
frame = sql.read_table('test_index_label', self.conn)
535+
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
536+
"Specified index_labels not written to database")
537+
538+
# wrong length of index_label
539+
self.assertRaises(ValueError, sql.to_sql, temp_frame,
540+
'test_index_label', self.conn, if_exists='replace',
541+
index_label='C')
542+
543+
def test_read_table_columns(self):
544+
# test columns argument in read_table
545+
sql.to_sql(self.test_frame1, 'test_frame', self.conn)
546+
547+
cols = ['A', 'B']
548+
result = sql.read_table('test_frame', self.conn, columns=cols)
549+
self.assertEqual(result.columns.tolist(), cols,
550+
"Columns not correctly selected")
551+
552+
def test_read_table_index_col(self):
553+
# test columns argument in read_table
554+
sql.to_sql(self.test_frame1, 'test_frame', self.conn)
555+
556+
result = sql.read_table('test_frame', self.conn, index_col="index")
557+
self.assertEqual(result.index.names, ["index"],
558+
"index_col not correctly set")
559+
560+
result = sql.read_table('test_frame', self.conn, index_col=["A", "B"])
561+
self.assertEqual(result.index.names, ["A", "B"],
562+
"index_col not correctly set")
563+
564+
result = sql.read_table('test_frame', self.conn, index_col=["A", "B"],
565+
columns=["C", "D"])
566+
self.assertEqual(result.index.names, ["A", "B"],
567+
"index_col not correctly set")
568+
self.assertEqual(result.columns.tolist(), ["C", "D"],
569+
"columns not set correctly whith index_col")
570+
504571

505572
class TestSQLLegacyApi(_TestSQLApi):
573+
506574
"""Test the public legacy API
507575
"""
508576
flavor = 'sqlite'
@@ -554,6 +622,23 @@ def test_sql_open_close(self):
554622

555623
tm.assert_frame_equal(self.test_frame2, result)
556624

625+
def test_roundtrip(self):
626+
# this test otherwise fails, Legacy mode still uses 'pandas_index'
627+
# as default index column label
628+
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
629+
con=self.conn, flavor='sqlite')
630+
result = sql.read_sql(
631+
'SELECT * FROM test_frame_roundtrip',
632+
con=self.conn,
633+
flavor='sqlite')
634+
635+
# HACK!
636+
result.index = self.test_frame1.index
637+
result.set_index('pandas_index', inplace=True)
638+
result.index.astype(int)
639+
result.index.name = None
640+
tm.assert_frame_equal(result, self.test_frame1)
641+
557642

558643
class _TestSQLAlchemy(PandasSQLTest):
559644
"""
@@ -776,6 +861,16 @@ def setUp(self):
776861

777862
self._load_test1_data()
778863

864+
def _roundtrip(self):
865+
# overwrite parent function (level_0 -> pandas_index in legacy mode)
866+
self.drop_table('test_frame_roundtrip')
867+
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
868+
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
869+
result.set_index('pandas_index', inplace=True)
870+
result.index.name = None
871+
872+
tm.assert_frame_equal(result, self.test_frame1)
873+
779874
def test_invalid_flavor(self):
780875
self.assertRaises(
781876
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')

0 commit comments

Comments
 (0)