Skip to content

SQL: add multi-index support to legacy mode #6883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,7 @@ def maybe_asscalar(self, i):
except AttributeError:
return i

def insert(self):
ins = self.insert_statement()
data_list = []

def insert_data(self):
if self.index is not None:
temp = self.frame.copy()
temp.index.names = self.index
Expand All @@ -451,6 +448,12 @@ def insert(self):
else:
temp = self.frame

return temp

def insert(self):
ins = self.insert_statement()
data_list = []
temp = self.insert_data()
keys = temp.columns

for t in temp.itertuples():
Expand Down Expand Up @@ -785,7 +788,7 @@ def insert_statement(self):
wld = _SQL_SYMB[flv]['wld'] # wildcard char

if self.index is not None:
safe_names.insert(0, self.index)
[safe_names.insert(0, idx) for idx in self.index[::-1]]

bracketed_names = [br_l + column + br_r for column in safe_names]
col_names = ','.join(bracketed_names)
Expand All @@ -796,26 +799,18 @@ def insert_statement(self):

def insert(self):
ins = self.insert_statement()
temp = self.insert_data()
data_list = []

for t in temp.itertuples():
data = tuple((self.maybe_asscalar(v) for v in t[1:]))
data_list.append(data)

cur = self.pd_sql.con.cursor()
for r in self.frame.itertuples():
data = [self.maybe_asscalar(v) for v in r[1:]]
if self.index is not None:
data.insert(0, self.maybe_asscalar(r[0]))
cur.execute(ins, tuple(data))
cur.executemany(ins, data_list)
cur.close()
self.pd_sql.con.commit()

def _index_name(self, index, index_label):
if index is True:
if self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
else:
return 'pandas_index'
elif isinstance(index, string_types):
return index
else:
return None

def _create_table_statement(self):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."

Expand All @@ -824,8 +819,10 @@ def _create_table_statement(self):
column_types = [self._sql_type_name(typ) for typ in self.frame.dtypes]

if self.index is not None:
safe_columns.insert(0, self.index)
column_types.insert(0, self._sql_type_name(self.frame.index.dtype))
for i, idx_label in enumerate(self.index[::-1]):
safe_columns.insert(0, idx_label)
column_types.insert(0, self._sql_type_name(self.frame.index.get_level_values(i).dtype))

flv = self.pd_sql.flavor

br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
Expand Down Expand Up @@ -935,15 +932,16 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
----------
frame: DataFrame
name: name of SQL table
flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite'
flavor: {'sqlite', 'mysql'}, default 'sqlite'
if_exists: {'fail', 'replace', 'append'}, default 'fail'
fail: If table exists, do nothing.
replace: If table exists, drop it, recreate it, and insert data.
append: If table exists, insert data. Create if does not exist.
index_label : ignored (only used in sqlalchemy mode)

"""
table = PandasSQLTableLegacy(
name, self, frame=frame, index=index, if_exists=if_exists)
name, self, frame=frame, index=index, if_exists=if_exists,
index_label=index_label)
table.insert()

def has_table(self, name):
Expand Down Expand Up @@ -991,13 +989,47 @@ def read_frame(*args, **kwargs):
return read_sql(*args, **kwargs)


def write_frame(*args, **kwargs):
def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs):
"""DEPRECIATED - use to_sql

Write records stored in a DataFrame to a SQL database.

Parameters
----------
frame : DataFrame
name : string
con : DBAPI2 connection
flavor : {'sqlite', 'mysql'}, default 'sqlite'
The flavor of SQL to use.
if_exists : {'fail', 'replace', 'append'}, default 'fail'
- fail: If table exists, do nothing.
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
index : boolean, default False
Write DataFrame index as a column

Notes
-----
This function is deprecated in favor of ``to_sql``. There are however
two differences:

- With ``to_sql`` the index is written to the sql database by default. To
keep the behaviour of this function you need to specify ``index=False``.
- The new ``to_sql`` function supports sqlalchemy engines to work with
different sql flavors.

See also
--------
pandas.DataFrame.to_sql

"""
warnings.warn("write_frame is depreciated, use to_sql", DeprecationWarning)
return to_sql(*args, **kwargs)

# for backwards compatibility, set index=False when not specified
index = kwargs.pop('index', False)
return to_sql(frame, name, con, flavor=flavor, if_exists=if_exists,
index=index, **kwargs)


# Append wrapped function docstrings
read_frame.__doc__ += read_sql.__doc__
write_frame.__doc__ += to_sql.__doc__
69 changes: 21 additions & 48 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,46 +460,33 @@ def test_date_and_index(self):
issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")


class TestSQLApi(_TestSQLApi):

"""Test the public API as it would be used directly
"""
flavor = 'sqlite'

def connect(self):
if SQLALCHEMY_INSTALLED:
return sqlalchemy.create_engine('sqlite:///:memory:')
else:
raise nose.SkipTest('SQLAlchemy not installed')

def test_to_sql_index_label(self):
temp_frame = DataFrame({'col1': range(4)})

# no index name, defaults to 'index'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'index')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label='other_label')
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")

# using the index name
temp_frame.index.name = 'index_name'
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'index_name',
"Index name not written to database")

# has index name, but specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label='other_label')
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")

Expand All @@ -509,29 +496,29 @@ def test_to_sql_index_label_multiindex(self):

# no index name, defaults to 'level_0' and 'level_1'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'level_0')
self.assertEqual(frame.columns[1], 'level_1')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['A', 'B'])
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Specified index_labels not written to database")

# using the index name
temp_frame.index.names = ['A', 'B']
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Index names not written to database")

# has index name, but specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['C', 'D'])
frame = sql.read_table('test_index_label', self.conn)
frame = sql.read_sql('SELECT * FROM test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
"Specified index_labels not written to database")

Expand All @@ -540,6 +527,19 @@ def test_to_sql_index_label_multiindex(self):
'test_index_label', self.conn, if_exists='replace',
index_label='C')


class TestSQLApi(_TestSQLApi):

"""Test the public API as it would be used directly
"""
flavor = 'sqlite'

def connect(self):
if SQLALCHEMY_INSTALLED:
return sqlalchemy.create_engine('sqlite:///:memory:')
else:
raise nose.SkipTest('SQLAlchemy not installed')

def test_read_table_columns(self):
# test columns argument in read_table
sql.to_sql(self.test_frame1, 'test_frame', self.conn)
Expand Down Expand Up @@ -622,23 +622,6 @@ def test_sql_open_close(self):

tm.assert_frame_equal(self.test_frame2, result)

def test_roundtrip(self):
# this test otherwise fails, Legacy mode still uses 'pandas_index'
# as default index column label
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
con=self.conn, flavor='sqlite')
result = sql.read_sql(
'SELECT * FROM test_frame_roundtrip',
con=self.conn,
flavor='sqlite')

# HACK!
result.index = self.test_frame1.index
result.set_index('pandas_index', inplace=True)
result.index.astype(int)
result.index.name = None
tm.assert_frame_equal(result, self.test_frame1)


class _TestSQLAlchemy(PandasSQLTest):
"""
Expand Down Expand Up @@ -861,16 +844,6 @@ def setUp(self):

self._load_test1_data()

def _roundtrip(self):
# overwrite parent function (level_0 -> pandas_index in legacy mode)
self.drop_table('test_frame_roundtrip')
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
result.set_index('pandas_index', inplace=True)
result.index.name = None

tm.assert_frame_equal(result, self.test_frame1)

def test_invalid_flavor(self):
self.assertRaises(
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')
Expand Down