Skip to content

Commit ca44b2e

Browse files
SQL: add multi-index support to legacy mode (pandas-dev#6881)
and, at the same time, add index_label kwarg support
1 parent 18bd0d6 commit ca44b2e

File tree

2 files changed

+45
-74
lines changed

2 files changed

+45
-74
lines changed

pandas/io/sql.py

+24-26
Original file line number | Diff line number | Diff line change
@@ -436,10 +436,7 @@ def maybe_asscalar(self, i):
436436
except AttributeError:
437437
return i
438438

439-
def insert(self):
440-
ins = self.insert_statement()
441-
data_list = []
442-
439+
def insert_data(self):
443440
if self.index is not None:
444441
temp = self.frame.copy()
445442
temp.index.names = self.index
@@ -451,6 +448,12 @@ def insert(self):
451448
else:
452449
temp = self.frame
453450

451+
return temp
452+
453+
def insert(self):
454+
ins = self.insert_statement()
455+
data_list = []
456+
temp = self.insert_data()
454457
keys = temp.columns
455458

456459
for t in temp.itertuples():
@@ -785,7 +788,7 @@ def insert_statement(self):
785788
wld = _SQL_SYMB[flv]['wld'] # wildcard char
786789

787790
if self.index is not None:
788-
safe_names.insert(0, self.index)
791+
[safe_names.insert(0, idx) for idx in self.index[::-1]]
789792

790793
bracketed_names = [br_l + column + br_r for column in safe_names]
791794
col_names = ','.join(bracketed_names)
@@ -796,26 +799,18 @@ def insert_statement(self):
796799

797800
def insert(self):
798801
ins = self.insert_statement()
802+
temp = self.insert_data()
803+
data_list = []
804+
805+
for t in temp.itertuples():
806+
data = tuple((self.maybe_asscalar(v) for v in t[1:]))
807+
data_list.append(data)
808+
799809
cur = self.pd_sql.con.cursor()
800-
for r in self.frame.itertuples():
801-
data = [self.maybe_asscalar(v) for v in r[1:]]
802-
if self.index is not None:
803-
data.insert(0, self.maybe_asscalar(r[0]))
804-
cur.execute(ins, tuple(data))
810+
cur.executemany(ins, data_list)
805811
cur.close()
806812
self.pd_sql.con.commit()
807813

808-
def _index_name(self, index, index_label):
809-
if index is True:
810-
if self.frame.index.name is not None:
811-
return _safe_col_name(self.frame.index.name)
812-
else:
813-
return 'pandas_index'
814-
elif isinstance(index, string_types):
815-
return index
816-
else:
817-
return None
818-
819814
def _create_table_statement(self):
820815
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
821816

@@ -824,8 +819,10 @@ def _create_table_statement(self):
824819
column_types = [self._sql_type_name(typ) for typ in self.frame.dtypes]
825820

826821
if self.index is not None:
827-
safe_columns.insert(0, self.index)
828-
column_types.insert(0, self._sql_type_name(self.frame.index.dtype))
822+
for i, idx_label in enumerate(self.index[::-1]):
823+
safe_columns.insert(0, idx_label)
824+
column_types.insert(0, self._sql_type_name(self.frame.index.get_level_values(i).dtype))
825+
829826
flv = self.pd_sql.flavor
830827

831828
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
@@ -935,15 +932,16 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
935932
----------
936933
frame: DataFrame
937934
name: name of SQL table
938-
flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite'
935+
flavor: {'sqlite', 'mysql'}, default 'sqlite'
939936
if_exists: {'fail', 'replace', 'append'}, default 'fail'
940937
fail: If table exists, do nothing.
941938
replace: If table exists, drop it, recreate it, and insert data.
942939
append: If table exists, insert data. Create if does not exist.
943-
index_label : ignored (only used in sqlalchemy mode)
940+
944941
"""
945942
table = PandasSQLTableLegacy(
946-
name, self, frame=frame, index=index, if_exists=if_exists)
943+
name, self, frame=frame, index=index, if_exists=if_exists,
944+
index_label=index_label)
947945
table.insert()
948946

949947
def has_table(self, name):

pandas/io/tests/test_sql.py

+21-48
Original file line number | Diff line number | Diff line change
@@ -460,46 +460,33 @@ def test_date_and_index(self):
460460
issubclass(df.IntDateCol.dtype.type, np.datetime64),
461461
"IntDateCol loaded with incorrect type")
462462

463-
464-
class TestSQLApi(_TestSQLApi):
465-
466-
"""Test the public API as it would be used directly
467-
"""
468-
flavor = 'sqlite'
469-
470-
def connect(self):
471-
if SQLALCHEMY_INSTALLED:
472-
return sqlalchemy.create_engine('sqlite:///:memory:')
473-
else:
474-
raise nose.SkipTest('SQLAlchemy not installed')
475-
476463
def test_to_sql_index_label(self):
477464
temp_frame = DataFrame({'col1': range(4)})
478465

479466
# no index name, defaults to 'index'
480467
sql.to_sql(temp_frame, 'test_index_label', self.conn)
481-
frame = sql.read_table('test_index_label', self.conn)
468+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
482469
self.assertEqual(frame.columns[0], 'index')
483470

484471
# specifying index_label
485472
sql.to_sql(temp_frame, 'test_index_label', self.conn,
486473
if_exists='replace', index_label='other_label')
487-
frame = sql.read_table('test_index_label', self.conn)
474+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
488475
self.assertEqual(frame.columns[0], 'other_label',
489476
"Specified index_label not written to database")
490477

491478
# using the index name
492479
temp_frame.index.name = 'index_name'
493480
sql.to_sql(temp_frame, 'test_index_label', self.conn,
494481
if_exists='replace')
495-
frame = sql.read_table('test_index_label', self.conn)
482+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
496483
self.assertEqual(frame.columns[0], 'index_name',
497484
"Index name not written to database")
498485

499486
# has index name, but specifying index_label
500487
sql.to_sql(temp_frame, 'test_index_label', self.conn,
501488
if_exists='replace', index_label='other_label')
502-
frame = sql.read_table('test_index_label', self.conn)
489+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
503490
self.assertEqual(frame.columns[0], 'other_label',
504491
"Specified index_label not written to database")
505492

@@ -509,29 +496,29 @@ def test_to_sql_index_label_multiindex(self):
509496

510497
# no index name, defaults to 'level_0' and 'level_1'
511498
sql.to_sql(temp_frame, 'test_index_label', self.conn)
512-
frame = sql.read_table('test_index_label', self.conn)
499+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
513500
self.assertEqual(frame.columns[0], 'level_0')
514501
self.assertEqual(frame.columns[1], 'level_1')
515502

516503
# specifying index_label
517504
sql.to_sql(temp_frame, 'test_index_label', self.conn,
518505
if_exists='replace', index_label=['A', 'B'])
519-
frame = sql.read_table('test_index_label', self.conn)
506+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
520507
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
521508
"Specified index_labels not written to database")
522509

523510
# using the index name
524511
temp_frame.index.names = ['A', 'B']
525512
sql.to_sql(temp_frame, 'test_index_label', self.conn,
526513
if_exists='replace')
527-
frame = sql.read_table('test_index_label', self.conn)
514+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
528515
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
529516
"Index names not written to database")
530517

531518
# has index name, but specifying index_label
532519
sql.to_sql(temp_frame, 'test_index_label', self.conn,
533520
if_exists='replace', index_label=['C', 'D'])
534-
frame = sql.read_table('test_index_label', self.conn)
521+
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
535522
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
536523
"Specified index_labels not written to database")
537524

@@ -540,6 +527,19 @@ def test_to_sql_index_label_multiindex(self):
540527
'test_index_label', self.conn, if_exists='replace',
541528
index_label='C')
542529

530+
531+
class TestSQLApi(_TestSQLApi):
532+
533+
"""Test the public API as it would be used directly
534+
"""
535+
flavor = 'sqlite'
536+
537+
def connect(self):
538+
if SQLALCHEMY_INSTALLED:
539+
return sqlalchemy.create_engine('sqlite:///:memory:')
540+
else:
541+
raise nose.SkipTest('SQLAlchemy not installed')
542+
543543
def test_read_table_columns(self):
544544
# test columns argument in read_table
545545
sql.to_sql(self.test_frame1, 'test_frame', self.conn)
@@ -622,23 +622,6 @@ def test_sql_open_close(self):
622622

623623
tm.assert_frame_equal(self.test_frame2, result)
624624

625-
def test_roundtrip(self):
626-
# this test otherwise fails, Legacy mode still uses 'pandas_index'
627-
# as default index column label
628-
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
629-
con=self.conn, flavor='sqlite')
630-
result = sql.read_sql(
631-
'SELECT * FROM test_frame_roundtrip',
632-
con=self.conn,
633-
flavor='sqlite')
634-
635-
# HACK!
636-
result.index = self.test_frame1.index
637-
result.set_index('pandas_index', inplace=True)
638-
result.index.astype(int)
639-
result.index.name = None
640-
tm.assert_frame_equal(result, self.test_frame1)
641-
642625

643626
class _TestSQLAlchemy(PandasSQLTest):
644627
"""
@@ -861,16 +844,6 @@ def setUp(self):
861844

862845
self._load_test1_data()
863846

864-
def _roundtrip(self):
865-
# overwrite parent function (level_0 -> pandas_index in legacy mode)
866-
self.drop_table('test_frame_roundtrip')
867-
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
868-
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
869-
result.set_index('pandas_index', inplace=True)
870-
result.index.name = None
871-
872-
tm.assert_frame_equal(result, self.test_frame1)
873-
874847
def test_invalid_flavor(self):
875848
self.assertRaises(
876849
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')

0 commit comments

Comments
 (0)