Skip to content

Commit d164182

Browse files
committed
add support for specifying secondary indexes with to_sql
1 parent 2ea0601 commit d164182

File tree

3 files changed

+101
-12
lines changed

3 files changed

+101
-12
lines changed

pandas/core/generic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,8 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
11171117
**kwargs)
11181118

11191119
def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
1120-
index=True, index_label=None, chunksize=None, dtype=None):
1120+
index=True, index_label=None, chunksize=None, dtype=None,
1121+
indexes=None):
11211122
"""
11221123
Write records stored in a DataFrame to a SQL database.
11231124
@@ -1157,7 +1158,7 @@ def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
11571158
from pandas.io import sql
11581159
sql.to_sql(self, name, con, flavor=flavor, schema=schema,
11591160
if_exists=if_exists, index=index, index_label=index_label,
1160-
chunksize=chunksize, dtype=dtype)
1161+
chunksize=chunksize, dtype=dtype, indexes=indexes)
11611162

11621163
def to_pickle(self, path):
11631164
"""

pandas/io/sql.py

+28-9
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
516516

517517

518518
def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
519-
index=True, index_label=None, chunksize=None, dtype=None):
519+
index=True, index_label=None, chunksize=None, dtype=None,
520+
indexes=None):
520521
"""
521522
Write records stored in a DataFrame to a SQL database.
522523
@@ -568,7 +569,7 @@ def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
568569

569570
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
570571
index_label=index_label, schema=schema,
571-
chunksize=chunksize, dtype=dtype)
572+
chunksize=chunksize, dtype=dtype, indexes=indexes)
572573

573574

574575
def has_table(table_name, con, flavor='sqlite', schema=None):
@@ -653,12 +654,13 @@ class SQLTable(PandasObject):
653654

654655
def __init__(self, name, pandas_sql_engine, frame=None, index=True,
655656
if_exists='fail', prefix='pandas', index_label=None,
656-
schema=None, keys=None, dtype=None):
657+
schema=None, keys=None, dtype=None, indexes=None):
657658
self.name = name
658659
self.pd_sql = pandas_sql_engine
659660
self.prefix = prefix
660661
self.frame = frame
661662
self.index = self._index_name(index, index_label)
663+
self.indexes = indexes
662664
self.schema = schema
663665
self.if_exists = if_exists
664666
self.keys = keys
@@ -849,18 +851,33 @@ def _index_name(self, index, index_label):
849851
else:
850852
return None
851853

854+
def _is_column_indexed(self, label):
    """
    Decide whether the column ``label`` should get a secondary index.

    A column is indexed when it is listed in ``self.indexes``, or when it
    is one of the frame index columns in ``self.index`` and the primary
    key in ``self.keys`` does not start with the same columns as
    ``self.index`` up to and including ``label``.

    Parameters
    ----------
    label : string
        Name of the column to check.

    Returns
    -------
    boolean
    """
    def _as_list(value):
        # A bare label (e.g. keys='id') has to become a one-item list:
        # slicing or ``in`` on a string would compare characters, and a
        # tuple slice never compares equal to a list slice.
        if isinstance(value, (type(''), type(u''))):
            return [value]
        try:
            return list(value)
        except TypeError:
            return [value]

    if self.indexes is not None and label in _as_list(self.indexes):
        return True

    if self.index is None or label not in self.index:
        return False

    if self.keys is None:
        return True

    # No separate index is needed only when the primary key begins with
    # exactly the frame index columns up to and including this one.
    col_nr = self.index.index(label) + 1
    return _as_list(self.keys)[:col_nr] != list(self.index[:col_nr])
867+
852868
def _get_column_names_and_types(self, dtype_mapper):
853869
column_names_and_types = []
854870
if self.index is not None:
855871
for i, idx_label in enumerate(self.index):
856872
idx_type = dtype_mapper(
857873
self.frame.index.get_level_values(i))
858-
column_names_and_types.append((idx_label, idx_type, True))
874+
indexed = self._is_column_indexed(idx_label)
875+
column_names_and_types.append((idx_label, idx_type, indexed))
859876

860877
column_names_and_types += [
861878
(text_type(self.frame.columns[i]),
862879
dtype_mapper(self.frame.iloc[:, i]),
863-
False)
880+
self._is_column_indexed(text_type(self.frame.columns[i])))
864881
for i in range(len(self.frame.columns))
865882
]
866883

@@ -1205,7 +1222,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
12051222
read_sql = read_query
12061223

12071224
def to_sql(self, frame, name, if_exists='fail', index=True,
1208-
index_label=None, schema=None, chunksize=None, dtype=None):
1225+
index_label=None, schema=None, chunksize=None, dtype=None,
1226+
indexes=None):
12091227
"""
12101228
Write records stored in a DataFrame to a SQL database.
12111229
@@ -1245,7 +1263,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
12451263

12461264
table = SQLTable(name, self, frame=frame, index=index,
12471265
if_exists=if_exists, index_label=index_label,
1248-
schema=schema, dtype=dtype)
1266+
schema=schema, dtype=dtype, indexes=indexes)
12491267
table.create()
12501268
table.insert(chunksize)
12511269
if (not name.isdigit() and not name.islower()):
@@ -1620,7 +1638,8 @@ def _fetchall_as_list(self, cur):
16201638
return result
16211639

16221640
def to_sql(self, frame, name, if_exists='fail', index=True,
1623-
index_label=None, schema=None, chunksize=None, dtype=None):
1641+
index_label=None, schema=None, chunksize=None, dtype=None,
1642+
indexes=None):
16241643
"""
16251644
Write records stored in a DataFrame to a SQL database.
16261645
@@ -1657,7 +1676,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
16571676

16581677
table = SQLiteTable(name, self, frame=frame, index=index,
16591678
if_exists=if_exists, index_label=index_label,
1660-
dtype=dtype)
1679+
dtype=dtype, indexes=indexes)
16611680
table.create()
16621681
table.insert(chunksize)
16631682

pandas/io/tests/test_sql.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,14 @@ def _load_test3_data(self):
309309

310310
self.test_frame3 = DataFrame(data, columns=columns)
311311

312+
def _load_test4_data(self):
    """Build ``self.test_frame4``: random values under a two-level index.

    The frame has 10 rows and 2 columns; its MultiIndex levels are named
    'color' and 'food' and are drawn at random from two labels each.
    """
    n_rows = 10
    # Draw order (colors, foods, then values) is kept so a seeded
    # generator yields the same frame as before.
    level_values = [
        np.random.choice(['red', 'green'], size=n_rows),
        np.random.choice(['eggs', 'ham'], size=n_rows),
    ]
    multi_index = pd.MultiIndex.from_arrays(level_values,
                                            names=['color', 'food'])
    values = np.random.randn(n_rows, 2)
    self.test_frame4 = DataFrame(values, index=multi_index)
319+
312320
def _load_raw_sql(self):
313321
self.drop_table('types_test_data')
314322
self._get_exec().execute(SQL_STRINGS['create_test_types'][self.flavor])
@@ -512,6 +520,7 @@ def setUp(self):
512520
self._load_test1_data()
513521
self._load_test2_data()
514522
self._load_test3_data()
523+
self._load_test4_data()
515524
self._load_raw_sql()
516525

517526
def test_read_sql_iris(self):
@@ -933,7 +942,7 @@ def test_warning_case_insensitive_table_name(self):
933942
def _get_index_columns(self, tbl_name):
    """
    Return the column names of each index reflected for ``tbl_name``.

    Parameters
    ----------
    tbl_name : string
        Name of the table to inspect through ``self.conn``.

    Returns
    -------
    list
        One entry per index, each the index's ``'column_names'`` value.
    """
    from sqlalchemy.engine import reflection
    insp = reflection.Inspector.from_engine(self.conn)
    # Inspect the requested table rather than the previously hard-coded
    # 'test_index_saved', so the helper works for any test table.
    return [ix['column_names'] for ix in insp.get_indexes(tbl_name)]
939948

@@ -966,6 +975,66 @@ def test_to_sql_read_sql_with_database_uri(self):
966975
tm.assert_frame_equal(test_frame1, test_frame3)
967976
tm.assert_frame_equal(test_frame1, test_frame4)
968977

978+
def test_to_sql_column_indexes(self):
    # Each column named in ``indexes`` is expected to get its own index,
    # even though the frame index itself is not written (index=False).
    temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
    sql.to_sql(temp_frame, 'test_to_sql_column_indexes', self.conn,
               index=False, if_exists='replace', indexes=['col1', 'col2'])
    ix_cols = self._get_index_columns('test_to_sql_column_indexes')
    self.assertEqual(sorted(ix_cols), [['col1'], ['col2']],
                     "columns are not correctly indexed")
985+
986+
def test_sqltable_key_and_multiindex_no_pk(self):
    # With no ``keys`` given, both index levels are expected to carry a
    # secondary index and the table to have no primary key.
    pandas_sql = sql.SQLDatabase(self.conn)
    table = sql.SQLTable('test_sqltable_key_and_multiindex_no_pk',
                         pandas_sql, frame=self.test_frame4, index=True)
    table_meta = table.table.tometadata(table.pd_sql.meta)
    secondary = sorted(ix.columns.keys() for ix in table_meta.indexes)
    pk_columns = table_meta.primary_key.columns.keys()
    self.assertListEqual([['color'], ['food']], secondary,
                         "Wrong secondary indexes")
    self.assertListEqual([], pk_columns,
                         "There should be no primary keys")
997+
998+
def test_sqltable_key_and_multiindex_one_pk(self):
    # With the first index level as primary key, only the second level
    # is expected to get a secondary index.
    pandas_sql = sql.SQLDatabase(self.conn)
    table = sql.SQLTable('test_sqltable_key_and_multiindex_one_pk',
                         pandas_sql, frame=self.test_frame4, index=True,
                         keys=['color'])
    table_meta = table.table.tometadata(table.pd_sql.meta)
    secondary = [ix.columns.keys() for ix in table_meta.indexes]
    pk_columns = table_meta.primary_key.columns.keys()
    self.assertListEqual([['food']], secondary,
                         "Wrong secondary indexes")
    self.assertListEqual(['color'], pk_columns,
                         "Wrong primary keys")
1010+
1011+
def test_sqltable_key_and_multiindex_two_pk(self):
    # With both index levels in the primary key, no secondary index is
    # expected on either of them.
    pandas_sql = sql.SQLDatabase(self.conn)
    table = sql.SQLTable('test_sqltable_key_and_multiindex_two_pk',
                         pandas_sql, frame=self.test_frame4, index=True,
                         keys=['color', 'food'])
    table_meta = table.table.tometadata(table.pd_sql.meta)
    secondary = [ix.columns.keys() for ix in table_meta.indexes]
    pk_columns = table_meta.primary_key.columns.keys()
    self.assertListEqual([], secondary,
                         "There should be no secondary indexes")
    self.assertListEqual(['color', 'food'], pk_columns,
                         "Wrong primary keys")
1023+
1024+
def test_sqltable_no_double_key_and_index_index(self):
    # A frame index column that is also the primary key should not get
    # an additional secondary index.
    temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
    db = sql.SQLDatabase(self.conn)
    table = sql.SQLTable('test_sqltable_no_double_key_and_index_index', db,
                         frame=temp_frame, index=True, index_label='id',
                         keys=['id'], indexes=['col1', 'col2'])
    table_metadata = table.table.tometadata(table.pd_sql.meta)
    indexed_columns = [e.columns.keys() for e in table_metadata.indexes]
    # ``indexed_columns`` is a list of column-name lists, so the entry to
    # look for is ['id']; checking for the bare string 'id' could never
    # fail.
    self.assertNotIn(['id'], indexed_columns,
                     "Secondary Index found for primary key")

    self.assertListEqual(['id'], table_metadata.primary_key.columns.keys(),
                         "Primary key missing from table")
1037+
9691038
def _make_iris_table_metadata(self):
9701039
sa = sqlalchemy
9711040
metadata = sa.MetaData()

0 commit comments

Comments
 (0)