Skip to content

ENH: SQL multiindex support #6735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 58 additions & 29 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
Legacy mode not supported
meta : SQLAlchemy meta, optional
If omitted MetaData is reflected from engine
index_col : string, optional
Column to set as index
index_col : string or sequence of strings, optional
Column(s) to set as index.
coerce_float : boolean, default True
Attempt to convert values of non-string, non-numeric objects (like
decimal.Decimal) to floating point. Can result in loss of precision.
Expand All @@ -324,7 +324,7 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
to the keyword arguments of :func:`pandas.to_datetime`
Especially useful with databases without native Datetime support,
such as SQLite
columns : list
columns : list, optional
List of column names to select from sql table

Returns
Expand All @@ -340,7 +340,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
table = pandas_sql.read_table(table_name,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates)
parse_dates=parse_dates,
columns=columns)

if table is not None:
return table
Expand Down Expand Up @@ -438,19 +439,25 @@ def maybe_asscalar(self, i):
def insert(self):
ins = self.insert_statement()
data_list = []
# to avoid if check for every row
keys = self.frame.columns

if self.index is not None:
for t in self.frame.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data[self.index] = self.maybe_asscalar(t[0])
data_list.append(data)
temp = self.frame.copy()
temp.index.names = self.index
try:
temp.reset_index(inplace=True)
except ValueError as err:
raise ValueError(
"duplicate name in index/columns: {0}".format(err))
else:
for t in self.frame.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data_list.append(data)
temp = self.frame

keys = temp.columns

for t in temp.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data_list.append(data)

self.pd_sql.execute(ins, data_list)

def read(self, coerce_float=True, parse_dates=None, columns=None):
Expand All @@ -459,7 +466,7 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
from sqlalchemy import select
cols = [self.table.c[n] for n in columns]
if self.index is not None:
cols.insert(0, self.table.c[self.index])
[cols.insert(0, self.table.c[idx]) for idx in self.index[::-1]]
sql_select = select(cols)
else:
sql_select = self.table.select()
Expand All @@ -476,22 +483,33 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
if self.index is not None:
self.frame.set_index(self.index, inplace=True)

# Assume if the index in prefix_index format, we gave it a name
# and should return it nameless
if self.index == self.prefix + '_index':
self.frame.index.name = None

return self.frame

def _index_name(self, index, index_label):
# for writing: index=True to include index in sql table
if index is True:
nlevels = self.frame.index.nlevels
# if index_label is specified, set this as index name(s)
if index_label is not None:
return _safe_col_name(index_label)
elif self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
if not isinstance(index_label, list):
index_label = [index_label]
if len(index_label) != nlevels:
raise ValueError(
"Length of 'index_label' should match number of "
"levels, which is {0}".format(nlevels))
else:
return index_label
# return the used column labels for the index columns
if nlevels == 1 and 'index' not in self.frame.columns and self.frame.index.name is None:
return ['index']
else:
return self.prefix + '_index'
return [l if l is not None else "level_{0}".format(i)
for i, l in enumerate(self.frame.index.names)]

# for reading: index=(list of) string to specify column to set as index
elif isinstance(index, string_types):
return [index]
elif isinstance(index, list):
return index
else:
return None
Expand All @@ -506,10 +524,10 @@ def _create_table_statement(self):
for name, typ in zip(safe_columns, column_types)]

if self.index is not None:
columns.insert(0, Column(self.index,
self._sqlalchemy_type(
self.frame.index),
index=True))
for i, idx_label in enumerate(self.index[::-1]):
idx_type = self._sqlalchemy_type(
self.frame.index.get_level_values(i))
columns.insert(0, Column(idx_label, idx_type, index=True))

return Table(self.name, self.pd_sql.meta, *columns)

Expand Down Expand Up @@ -787,6 +805,17 @@ def insert(self):
cur.close()
self.pd_sql.con.commit()

def _index_name(self, index, index_label):
if index is True:
if self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
else:
return 'pandas_index'
elif isinstance(index, string_types):
return index
else:
return None

def _create_table_statement(self):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."

Expand Down
109 changes: 102 additions & 7 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import nose
import numpy as np

from pandas import DataFrame, Series
from pandas import DataFrame, Series, MultiIndex
from pandas.compat import range, lrange, iteritems
#from pandas.core.datetools import format as date_format

Expand Down Expand Up @@ -266,7 +266,7 @@ def _roundtrip(self):
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')

result.set_index('pandas_index', inplace=True)
result.set_index('level_0', inplace=True)
# result.index.astype(int)

result.index.name = None
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_roundtrip(self):

# HACK!
result.index = self.test_frame1.index
result.set_index('pandas_index', inplace=True)
result.set_index('level_0', inplace=True)
result.index.astype(int)
result.index.name = None
tm.assert_frame_equal(result, self.test_frame1)
Expand Down Expand Up @@ -460,7 +460,9 @@ def test_date_and_index(self):
issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")


class TestSQLApi(_TestSQLApi):

"""Test the public API as it would be used directly
"""
flavor = 'sqlite'
Expand All @@ -474,10 +476,10 @@ def connect(self):
def test_to_sql_index_label(self):
temp_frame = DataFrame({'col1': range(4)})

# no index name, defaults to 'pandas_index'
# no index name, defaults to 'index'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'pandas_index')
self.assertEqual(frame.columns[0], 'index')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the benefit of changing the naming here? Not a big deal, but it would be nice to spell out the reason.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be consistent with other places in pandas (e.g. HDF uses 'index' / 'level_0' / 'level_1', I think, and I am now following the names that are given by reset_index). As far as I know, this is the only place where pandas_index is used. It is also new (in 0.13, writing the index was not included in to_sql), so this does not really change the behaviour for the user. See also #6642 (comment) for some discussion.


# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
Expand All @@ -487,11 +489,11 @@ def test_to_sql_index_label(self):
"Specified index_label not written to database")

# using the index name
temp_frame.index.name = 'index'
temp_frame.index.name = 'index_name'
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'index',
self.assertEqual(frame.columns[0], 'index_name',
"Index name not written to database")

# has index name, but specifying index_label
Expand All @@ -501,8 +503,74 @@ def test_to_sql_index_label(self):
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")

def test_to_sql_index_label_multiindex(self):
temp_frame = DataFrame({'col1': range(4)},
index=MultiIndex.from_product([('A0', 'A1'), ('B0', 'B1')]))

# no index name, defaults to 'level_0' and 'level_1'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'level_0')
self.assertEqual(frame.columns[1], 'level_1')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['A', 'B'])
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Specified index_labels not written to database")

# using the index name
temp_frame.index.names = ['A', 'B']
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Index names not written to database")

# has index name, but specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['C', 'D'])
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
"Specified index_labels not written to database")

# wrong length of index_label
self.assertRaises(ValueError, sql.to_sql, temp_frame,
'test_index_label', self.conn, if_exists='replace',
index_label='C')

def test_read_table_columns(self):
# test columns argument in read_table
sql.to_sql(self.test_frame1, 'test_frame', self.conn)

cols = ['A', 'B']
result = sql.read_table('test_frame', self.conn, columns=cols)
self.assertEqual(result.columns.tolist(), cols,
"Columns not correctly selected")

def test_read_table_index_col(self):
# test index_col argument in read_table
sql.to_sql(self.test_frame1, 'test_frame', self.conn)

result = sql.read_table('test_frame', self.conn, index_col="index")
self.assertEqual(result.index.names, ["index"],
"index_col not correctly set")

result = sql.read_table('test_frame', self.conn, index_col=["A", "B"])
self.assertEqual(result.index.names, ["A", "B"],
"index_col not correctly set")

result = sql.read_table('test_frame', self.conn, index_col=["A", "B"],
columns=["C", "D"])
self.assertEqual(result.index.names, ["A", "B"],
"index_col not correctly set")
self.assertEqual(result.columns.tolist(), ["C", "D"],
"columns not set correctly whith index_col")


class TestSQLLegacyApi(_TestSQLApi):

"""Test the public legacy API
"""
flavor = 'sqlite'
Expand Down Expand Up @@ -554,6 +622,23 @@ def test_sql_open_close(self):

tm.assert_frame_equal(self.test_frame2, result)

def test_roundtrip(self):
# this test otherwise fails, Legacy mode still uses 'pandas_index'
# as default index column label
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
con=self.conn, flavor='sqlite')
result = sql.read_sql(
'SELECT * FROM test_frame_roundtrip',
con=self.conn,
flavor='sqlite')

# HACK!
result.index = self.test_frame1.index
result.set_index('pandas_index', inplace=True)
result.index.astype(int)
result.index.name = None
tm.assert_frame_equal(result, self.test_frame1)


class _TestSQLAlchemy(PandasSQLTest):
"""
Expand Down Expand Up @@ -776,6 +861,16 @@ def setUp(self):

self._load_test1_data()

def _roundtrip(self):
# overwrite parent function (level_0 -> pandas_index in legacy mode)
self.drop_table('test_frame_roundtrip')
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
result.set_index('pandas_index', inplace=True)
result.index.name = None

tm.assert_frame_equal(result, self.test_frame1)

def test_invalid_flavor(self):
self.assertRaises(
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')
Expand Down