Skip to content

Commit 54678dd

Browse files
Merge pull request #8083 from artemyk/to_sql_create_indexes
BUG: When creating table, db indexes should be created from DataFrame indexes
2 parents 77d5f04 + df48524 commit 54678dd

File tree

2 files changed

+84
-18
lines changed

2 files changed

+84
-18
lines changed

pandas/io/sql.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -566,17 +566,17 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True,
566566
raise ValueError("Table '%s' already exists." % name)
567567
elif if_exists == 'replace':
568568
self.pd_sql.drop_table(self.name, self.schema)
569-
self.table = self._create_table_statement()
569+
self.table = self._create_table_setup()
570570
self.create()
571571
elif if_exists == 'append':
572572
self.table = self.pd_sql.get_table(self.name, self.schema)
573573
if self.table is None:
574-
self.table = self._create_table_statement()
574+
self.table = self._create_table_setup()
575575
else:
576576
raise ValueError(
577577
"'{0}' is not valid for if_exists".format(if_exists))
578578
else:
579-
self.table = self._create_table_statement()
579+
self.table = self._create_table_setup()
580580
self.create()
581581
else:
582582
# no data provided, read-only mode
@@ -703,23 +703,25 @@ def _get_column_names_and_types(self, dtype_mapper):
703703
for i, idx_label in enumerate(self.index):
704704
idx_type = dtype_mapper(
705705
self.frame.index.get_level_values(i))
706-
column_names_and_types.append((idx_label, idx_type))
706+
column_names_and_types.append((idx_label, idx_type, True))
707707

708708
column_names_and_types += [
709709
(str(self.frame.columns[i]),
710-
dtype_mapper(self.frame.iloc[:,i]))
710+
dtype_mapper(self.frame.iloc[:,i]),
711+
False)
711712
for i in range(len(self.frame.columns))
712713
]
714+
713715
return column_names_and_types
714716

715-
def _create_table_statement(self):
717+
def _create_table_setup(self):
716718
from sqlalchemy import Table, Column
717719

718720
column_names_and_types = \
719721
self._get_column_names_and_types(self._sqlalchemy_type)
720722

721-
columns = [Column(name, typ)
722-
for name, typ in column_names_and_types]
723+
columns = [Column(name, typ, index=is_index)
724+
for name, typ, is_index in column_names_and_types]
723725

724726
return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema)
725727

@@ -979,10 +981,12 @@ class PandasSQLTableLegacy(PandasSQLTable):
979981
Instead of a table variable just use the Create Table
980982
statement"""
981983
def sql_schema(self):
982-
return str(self.table)
984+
return str(";\n".join(self.table))
983985

984986
def create(self):
985-
self.pd_sql.execute(self.table)
987+
with self.pd_sql.con:
988+
for stmt in self.table:
989+
self.pd_sql.execute(stmt)
986990

987991
def insert_statement(self):
988992
names = list(map(str, self.frame.columns))
@@ -1026,14 +1030,17 @@ def insert(self, chunksize=None):
10261030
cur.executemany(ins, data_list)
10271031
cur.close()
10281032

1029-
def _create_table_statement(self):
1030-
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
1033+
def _create_table_setup(self):
1034+
"""Return a list of SQL statement that create a table reflecting the
1035+
structure of a DataFrame. The first entry will be a CREATE TABLE
1036+
statement while the rest will be CREATE INDEX statements
1037+
"""
10311038

10321039
column_names_and_types = \
10331040
self._get_column_names_and_types(self._sql_type_name)
10341041

10351042
pat = re.compile('\s+')
1036-
column_names = [col_name for col_name, _ in column_names_and_types]
1043+
column_names = [col_name for col_name, _, _ in column_names_and_types]
10371044
if any(map(pat.search, column_names)):
10381045
warnings.warn(_SAFE_NAMES_WARNING)
10391046

@@ -1044,13 +1051,21 @@ def _create_table_statement(self):
10441051

10451052
col_template = br_l + '%s' + br_r + ' %s'
10461053

1047-
columns = ',\n '.join(col_template %
1048-
x for x in column_names_and_types)
1054+
columns = ',\n '.join(col_template % (cname, ctype)
1055+
for cname, ctype, _ in column_names_and_types)
10491056
template = """CREATE TABLE %(name)s (
10501057
%(columns)s
10511058
)"""
1052-
create_statement = template % {'name': self.name, 'columns': columns}
1053-
return create_statement
1059+
create_stmts = [template % {'name': self.name, 'columns': columns}, ]
1060+
1061+
ix_tpl = "CREATE INDEX ix_{tbl}_{col} ON {tbl} ({br_l}{col}{br_r})"
1062+
for cname, _, is_index in column_names_and_types:
1063+
if not is_index:
1064+
continue
1065+
create_stmts.append(ix_tpl.format(tbl=self.name, col=cname,
1066+
br_l=br_l, br_r=br_r))
1067+
1068+
return create_stmts
10541069

10551070
def _sql_type_name(self, col):
10561071
pytype = col.dtype.type

pandas/io/tests/test_sql.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _load_test2_data(self):
199199
E=['1990-11-22', '1991-10-26', '1993-11-26', '1995-12-12']))
200200
df['E'] = to_datetime(df['E'])
201201

202-
self.test_frame3 = df
202+
self.test_frame2 = df
203203

204204
def _load_test3_data(self):
205205
columns = ['index', 'A', 'B']
@@ -324,6 +324,13 @@ def _execute_sql(self):
324324
row = iris_results.fetchone()
325325
tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa'])
326326

327+
def _to_sql_save_index(self):
    """Write a frame indexed on ``A`` and check the table has that index.

    The frame is stored through ``self.pandasSQL.to_sql`` and the indexed
    columns are read back with ``self._get_index_columns``.
    """
    table_name = 'test_to_sql_saves_index'
    records = [(1, 2.1, 'line1'), (2, 1.5, 'line2')]
    frame = DataFrame.from_records(records, columns=['A', 'B', 'C'],
                                   index=['A'])
    self.pandasSQL.to_sql(frame, table_name)
    # Exactly one index is expected, covering only the frame's index column.
    self.assertEqual(self._get_index_columns(table_name), [['A']])
333+
327334

328335
#------------------------------------------------------------------------------
329336
#--- Testing the public API
@@ -694,6 +701,13 @@ def test_warning_case_insensitive_table_name(self):
694701
# Verify some things
695702
self.assertEqual(len(w), 0, "Warning triggered for writing a table")
696703

704+
def _get_index_columns(self, tbl_name):
    """Return the indexed column names of ``tbl_name``, one list per index.

    The indexes are reflected through a SQLAlchemy ``Inspector`` built from
    ``self.conn``.

    Parameters
    ----------
    tbl_name : str
        Name of the table whose indexes are inspected.

    Returns
    -------
    list of list
        The ``'column_names'`` entry of every index reported for the table.
    """
    from sqlalchemy.engine import reflection

    insp = reflection.Inspector.from_engine(self.conn)
    # Reflect the table that was asked for rather than a fixed table name,
    # so the ``tbl_name`` argument is actually honoured.
    return [ix['column_names'] for ix in insp.get_indexes(tbl_name)]
710+
697711

698712
class TestSQLLegacyApi(_TestSQLApi):
699713
"""
@@ -1074,6 +1088,16 @@ def test_nan_string(self):
10741088
result = sql.read_sql_query('SELECT * FROM test_nan', self.conn)
10751089
tm.assert_frame_equal(result, df)
10761090

1091+
def _get_index_columns(self, tbl_name):
    """List the column names covered by each index on ``tbl_name``.

    Uses a SQLAlchemy ``Inspector`` created from ``self.conn``; the result
    holds one list of column names per reflected index.
    """
    from sqlalchemy.engine import reflection

    inspector = reflection.Inspector.from_engine(self.conn)
    return [index_info['column_names']
            for index_info in inspector.get_indexes(tbl_name)]
1097+
1098+
def test_to_sql_save_index(self):
    """Run the shared ``_to_sql_save_index`` check for this backend."""
    self._to_sql_save_index()
1100+
10771101

10781102
class TestSQLiteAlchemy(_TestSQLAlchemy):
10791103
"""
@@ -1368,6 +1392,20 @@ def test_datetime_time(self):
13681392
# test support for datetime.time
13691393
raise nose.SkipTest("datetime.time not supported for sqlite fallback")
13701394

1395+
def _get_index_columns(self, tbl_name):
    """Return the indexed column names of ``tbl_name``, one list per index.

    The index names are read from ``sqlite_master`` and each index is then
    expanded into its columns with ``PRAGMA index_info``.

    Parameters
    ----------
    tbl_name : str
        Name of the table whose indexes are inspected.

    Returns
    -------
    list of list
        For every index on the table, the list of its column names. Empty
        when the table has no index (or does not exist).
    """
    # Bind the table name as a query parameter instead of splicing it into
    # the SQL text, so a name containing a quote cannot break the statement.
    ixs = sql.read_sql_query(
        "SELECT * FROM sqlite_master WHERE type = 'index' "
        "AND tbl_name = ?", self.conn, params=(tbl_name,))
    ix_cols = []
    for ix_name in ixs.name:
        # PRAGMA arguments cannot be bound, so quote the index name as an
        # identifier, doubling any embedded double quote.
        quoted_name = '"%s"' % ix_name.replace('"', '""')
        ix_info = sql.read_sql_query(
            "PRAGMA index_info(%s)" % quoted_name, self.conn)
        ix_cols.append(ix_info.name.tolist())
    return ix_cols
1405+
1406+
def test_to_sql_save_index(self):
    """Run the shared ``_to_sql_save_index`` check for this backend."""
    self._to_sql_save_index()
1408+
13711409

13721410
class TestMySQLLegacy(TestSQLiteLegacy):
13731411
"""
@@ -1424,6 +1462,19 @@ def test_a_deprecation(self):
14241462
sql.has_table('test_frame1', self.conn, flavor='mysql'),
14251463
'Table not written to DB')
14261464

1465+
def _get_index_columns(self, tbl_name):
    """Group the indexed columns of ``tbl_name`` by index name.

    Runs ``SHOW INDEX IN <tbl_name>`` on ``self.conn`` and collects the
    ``Column_name`` values that share a ``Key_name``. Returns one list of
    column names per index.
    """
    index_rows = sql.read_sql_query(
        "SHOW INDEX IN %s" % tbl_name, self.conn)
    columns_by_index = {}
    key_and_column = zip(index_rows.Key_name, index_rows.Column_name)
    for key_name, column_name in key_and_column:
        # A multi-column index is reported as one row per column.
        columns_by_index.setdefault(key_name, []).append(column_name)
    return list(columns_by_index.values())
1474+
1475+
def test_to_sql_save_index(self):
    """Run the shared ``_to_sql_save_index`` check for this backend."""
    self._to_sql_save_index()
1477+
14271478

14281479
#------------------------------------------------------------------------------
14291480
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)

0 commit comments

Comments
 (0)