diff --git a/pandas/io/sql.py b/pandas/io/sql.py index efb8ce07ab60e..a80e8049ae627 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -310,8 +310,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True, Legacy mode not supported meta : SQLAlchemy meta, optional If omitted MetaData is reflected from engine - index_col : string, optional - Column to set as index + index_col : string or sequence of strings, optional + Column(s) to set as index. coerce_float : boolean, default True Attempt to convert values to non-string, non-numeric objects (like decimal.Decimal) to floating point. Can result in loss of Precision. @@ -324,7 +324,7 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True, to the keyword arguments of :func:`pandas.to_datetime` Especially useful with databases without native Datetime support, such as SQLite - columns : list + columns : list, optional List of column names to select from sql table Returns @@ -340,7 +340,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True, table = pandas_sql.read_table(table_name, index_col=index_col, coerce_float=coerce_float, - parse_dates=parse_dates) + parse_dates=parse_dates, + columns=columns) if table is not None: return table @@ -438,19 +439,25 @@ def maybe_asscalar(self, i): def insert(self): ins = self.insert_statement() data_list = [] - # to avoid if check for every row - keys = self.frame.columns + if self.index is not None: - for t in self.frame.itertuples(): - data = dict((k, self.maybe_asscalar(v)) - for k, v in zip(keys, t[1:])) - data[self.index] = self.maybe_asscalar(t[0]) - data_list.append(data) + temp = self.frame.copy() + temp.index.names = self.index + try: + temp.reset_index(inplace=True) + except ValueError as err: + raise ValueError( + "duplicate name in index/columns: {0}".format(err)) else: - for t in self.frame.itertuples(): - data = dict((k, self.maybe_asscalar(v)) - for k, v in zip(keys, t[1:])) - data_list.append(data) + temp = self.frame + + keys = temp.columns + + for t in temp.itertuples(): + data = dict((k, self.maybe_asscalar(v)) + for k, v in zip(keys, t[1:])) + data_list.append(data) + self.pd_sql.execute(ins, data_list) def read(self, coerce_float=True, parse_dates=None, columns=None): @@ -459,7 +466,7 @@ def read(self, coerce_float=True, parse_dates=None, columns=None): from sqlalchemy import select cols = [self.table.c[n] for n in columns] if self.index is not None: - cols.insert(0, self.table.c[self.index]) + [cols.insert(0, self.table.c[idx]) for idx in self.index[::-1]] sql_select = select(cols) else: sql_select = self.table.select() @@ -476,22 +483,33 @@ def read(self, coerce_float=True, parse_dates=None, columns=None): if self.index is not None: self.frame.set_index(self.index, inplace=True) - # Assume if the index in prefix_index format, we gave it a name - # and should return it nameless - if self.index == self.prefix + '_index': - self.frame.index.name = None - return self.frame def _index_name(self, index, index_label): + # for writing: index=True to include index in sql table if index is True: + nlevels = self.frame.index.nlevels + # if index_label is specified, set this as index name(s) if index_label is not None: - return _safe_col_name(index_label) - elif self.frame.index.name is not None: - return _safe_col_name(self.frame.index.name) + if not isinstance(index_label, list): + index_label = [index_label] + if len(index_label) != nlevels: + raise ValueError( + "Length of 'index_label' should match number of " + "levels, which is {0}".format(nlevels)) + else: + return index_label + # return the used column labels for the index columns + if nlevels == 1 and 'index' not in self.frame.columns and self.frame.index.name is None: + return ['index'] else: - return self.prefix + '_index' + return [l if l is not None else "level_{0}".format(i) + for i, l in enumerate(self.frame.index.names)] + + # for reading: index=(list of) string to specify column to set as index elif isinstance(index, string_types): + return [index] + elif isinstance(index, list): return index else: return None @@ -506,10 +524,10 @@ def _create_table_statement(self): for name, typ in zip(safe_columns, column_types)] if self.index is not None: - columns.insert(0, Column(self.index, - self._sqlalchemy_type( - self.frame.index), - index=True)) + for i, idx_label in enumerate(self.index[::-1]): + idx_type = self._sqlalchemy_type( + self.frame.index.get_level_values(i)) + columns.insert(0, Column(idx_label, idx_type, index=True)) return Table(self.name, self.pd_sql.meta, *columns) @@ -787,6 +805,17 @@ def insert(self): cur.close() self.pd_sql.con.commit() + def _index_name(self, index, index_label): + if index is True: + if self.frame.index.name is not None: + return _safe_col_name(self.frame.index.name) + else: + return 'pandas_index' + elif isinstance(index, string_types): + return index + else: + return None + def _create_table_statement(self): "Return a CREATE TABLE statement to suit the contents of a DataFrame." diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 57918e8315102..aa1b2516e4fb6 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -7,7 +7,7 @@ import nose import numpy as np -from pandas import DataFrame, Series +from pandas import DataFrame, Series, MultiIndex from pandas.compat import range, lrange, iteritems #from pandas.core.datetools import format as date_format @@ -266,7 +266,7 @@ def _roundtrip(self): self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip') result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip') - result.set_index('pandas_index', inplace=True) + result.set_index('level_0', inplace=True) # result.index.astype(int) result.index.name = None @@ -391,7 +391,7 @@ def test_roundtrip(self): # HACK! result.index = self.test_frame1.index - result.set_index('pandas_index', inplace=True) + result.set_index('level_0', inplace=True) result.index.astype(int) result.index.name = None tm.assert_frame_equal(result, self.test_frame1) @@ -460,7 +460,9 @@ def test_date_and_index(self): issubclass(df.IntDateCol.dtype.type, np.datetime64), "IntDateCol loaded with incorrect type") + class TestSQLApi(_TestSQLApi): + """Test the public API as it would be used directly """ flavor = 'sqlite' @@ -474,10 +476,10 @@ def connect(self): def test_to_sql_index_label(self): temp_frame = DataFrame({'col1': range(4)}) - # no index name, defaults to 'pandas_index' + # no index name, defaults to 'index' sql.to_sql(temp_frame, 'test_index_label', self.conn) frame = sql.read_table('test_index_label', self.conn) - self.assertEqual(frame.columns[0], 'pandas_index') + self.assertEqual(frame.columns[0], 'index') # specifying index_label sql.to_sql(temp_frame, 'test_index_label', self.conn, @@ -487,11 +489,11 @@ def test_to_sql_index_label(self): "Specified index_label not written to database") # using the index name - temp_frame.index.name = 'index' + temp_frame.index.name = 'index_name' sql.to_sql(temp_frame, 'test_index_label', self.conn, if_exists='replace') frame = sql.read_table('test_index_label', self.conn) - self.assertEqual(frame.columns[0], 'index', + self.assertEqual(frame.columns[0], 'index_name', "Index name not written to database") # has index name, but specifying index_label @@ -501,8 +503,74 @@ def test_to_sql_index_label(self): self.assertEqual(frame.columns[0], 'other_label', "Specified index_label not written to database") + def test_to_sql_index_label_multiindex(self): + temp_frame = DataFrame({'col1': range(4)}, + index=MultiIndex.from_product([('A0', 'A1'), ('B0', 'B1')])) + + # no index name, defaults to 'level_0' and 'level_1' + sql.to_sql(temp_frame, 'test_index_label', self.conn) + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[0], 'level_0') + self.assertEqual(frame.columns[1], 'level_1') + + # specifying index_label + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace', index_label=['A', 'B']) + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'], + "Specified index_labels not written to database") + + # using the index name + temp_frame.index.names = ['A', 'B'] + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace') + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'], + "Index names not written to database") + + # has index name, but specifying index_label + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace', index_label=['C', 'D']) + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'], + "Specified index_labels not written to database") + + # wrong length of index_label + self.assertRaises(ValueError, sql.to_sql, temp_frame, + 'test_index_label', self.conn, if_exists='replace', + index_label='C') + + def test_read_table_columns(self): + # test columns argument in read_table + sql.to_sql(self.test_frame1, 'test_frame', self.conn) + + cols = ['A', 'B'] + result = sql.read_table('test_frame', self.conn, columns=cols) + self.assertEqual(result.columns.tolist(), cols, + "Columns not correctly selected") + + def test_read_table_index_col(self): + # test columns argument in read_table + sql.to_sql(self.test_frame1, 'test_frame', self.conn) + + result = sql.read_table('test_frame', self.conn, index_col="index") + self.assertEqual(result.index.names, ["index"], + "index_col not correctly set") + + result = sql.read_table('test_frame', self.conn, index_col=["A", "B"]) + self.assertEqual(result.index.names, ["A", "B"], + "index_col not correctly set") + + result = sql.read_table('test_frame', self.conn, index_col=["A", "B"], + columns=["C", "D"]) + self.assertEqual(result.index.names, ["A", "B"], + "index_col not correctly set") + self.assertEqual(result.columns.tolist(), ["C", "D"], + "columns not set correctly whith index_col") + class TestSQLLegacyApi(_TestSQLApi): + """Test the public legacy API """ flavor = 'sqlite' @@ -554,6 +622,23 @@ def test_sql_open_close(self): tm.assert_frame_equal(self.test_frame2, result) + def test_roundtrip(self): + # this test otherwise fails, Legacy mode still uses 'pandas_index' + # as default index column label + sql.to_sql(self.test_frame1, 'test_frame_roundtrip', + con=self.conn, flavor='sqlite') + result = sql.read_sql( + 'SELECT * FROM test_frame_roundtrip', + con=self.conn, + flavor='sqlite') + + # HACK! + result.index = self.test_frame1.index + result.set_index('pandas_index', inplace=True) + result.index.astype(int) + result.index.name = None + tm.assert_frame_equal(result, self.test_frame1) + class _TestSQLAlchemy(PandasSQLTest): """ @@ -776,6 +861,16 @@ def setUp(self): self._load_test1_data() + def _roundtrip(self): + # overwrite parent function (level_0 -> pandas_index in legacy mode) + self.drop_table('test_frame_roundtrip') + self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip') + result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip') + result.set_index('pandas_index', inplace=True) + result.index.name = None + + tm.assert_frame_equal(result, self.test_frame1) + def test_invalid_flavor(self): self.assertRaises( NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')