Skip to content

Commit 5065232

Browse files
artemykArtemy Kolchinsky
authored and
Artemy Kolchinsky
committed
Rewrote/refactored get_schema to use methods from table classes (GH8232)
Fix creation of database indexes Tests fix Fixes Remove _executefunc Fixed tests to work with Python3 Schema rewriting Fixes Doc update Fixed postgres schema errors Fixes Fixes Moving if_exists checks to create Removing executescript Refactor insert differently Code review fixes Removed meta.schema from read_table Merges Merged #8208 Replacing izip with zip for Python3
1 parent 35a9527 commit 5065232

File tree

2 files changed

+106
-127
lines changed

2 files changed

+106
-127
lines changed

pandas/io/sql.py

+95-125
Original file line numberDiff line numberDiff line change
@@ -552,33 +552,19 @@ class PandasSQLTable(PandasObject):
552552
# TODO: support for multiIndex
553553
def __init__(self, name, pandas_sql_engine, frame=None, index=True,
554554
if_exists='fail', prefix='pandas', index_label=None,
555-
schema=None):
555+
schema=None, keys=None):
556556
self.name = name
557557
self.pd_sql = pandas_sql_engine
558558
self.prefix = prefix
559559
self.frame = frame
560560
self.index = self._index_name(index, index_label)
561561
self.schema = schema
562+
self.if_exists = if_exists
563+
self.keys = keys
562564

563565
if frame is not None:
564-
# We want to write a frame
565-
if self.pd_sql.has_table(self.name, self.schema):
566-
if if_exists == 'fail':
567-
raise ValueError("Table '%s' already exists." % name)
568-
elif if_exists == 'replace':
569-
self.pd_sql.drop_table(self.name, self.schema)
570-
self.table = self._create_table_setup()
571-
self.create()
572-
elif if_exists == 'append':
573-
self.table = self.pd_sql.get_table(self.name, self.schema)
574-
if self.table is None:
575-
self.table = self._create_table_setup()
576-
else:
577-
raise ValueError(
578-
"'{0}' is not valid for if_exists".format(if_exists))
579-
else:
580-
self.table = self._create_table_setup()
581-
self.create()
566+
# We want to initialize based on a dataframe
567+
self.table = self._create_table_setup()
582568
else:
583569
# no data provided, read-only mode
584570
self.table = self.pd_sql.get_table(self.name, self.schema)
@@ -593,9 +579,26 @@ def sql_schema(self):
593579
from sqlalchemy.schema import CreateTable
594580
return str(CreateTable(self.table))
595581

596-
def create(self):
582+
def _execute_create(self):
583+
# Inserting table into database, add to MetaData object
584+
self.table = self.table.tometadata(self.pd_sql.meta)
597585
self.table.create()
598586

587+
def create(self):
588+
if self.exists():
589+
if self.if_exists == 'fail':
590+
raise ValueError("Table '%s' already exists." % self.name)
591+
elif self.if_exists == 'replace':
592+
self.pd_sql.drop_table(self.name, self.schema)
593+
self._execute_create()
594+
elif self.if_exists == 'append':
595+
pass
596+
else:
597+
raise ValueError(
598+
"'{0}' is not valid for if_exists".format(self.if_exists))
599+
else:
600+
self._execute_create()
601+
599602
def insert_statement(self):
600603
return self.table.insert()
601604

@@ -634,28 +637,31 @@ def insert_data(self):
634637

635638
return column_names, data_list
636639

637-
def insert(self, chunksize=None):
640+
def get_session(self):
641+
con = self.pd_sql.engine.connect()
642+
return con.begin()
638643

639-
ins = self.insert_statement()
644+
def _execute_insert(self, trans, keys, data_iter):
645+
data = [dict( (k, v) for k, v in zip(keys, row) ) for row in data_iter]
646+
trans.connection.execute(self.insert_statement(), data)
647+
648+
def insert(self, chunksize=None):
640649
keys, data_list = self.insert_data()
641650

642651
nrows = len(self.frame)
643652
if chunksize is None:
644653
chunksize = nrows
645654
chunks = int(nrows / chunksize) + 1
646655

647-
con = self.pd_sql.engine.connect()
648-
with con.begin() as trans:
656+
with self.get_session() as trans:
649657
for i in range(chunks):
650658
start_i = i * chunksize
651659
end_i = min((i + 1) * chunksize, nrows)
652660
if start_i >= end_i:
653661
break
654662

655-
chunk_list = [arr[start_i:end_i] for arr in data_list]
656-
insert_list = [dict((k, v) for k, v in zip(keys, row))
657-
for row in zip(*chunk_list)]
658-
con.execute(ins, insert_list)
663+
chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
664+
self._execute_insert(trans, keys, chunk_iter)
659665

660666
def read(self, coerce_float=True, parse_dates=None, columns=None):
661667

@@ -729,15 +735,27 @@ def _get_column_names_and_types(self, dtype_mapper):
729735
return column_names_and_types
730736

731737
def _create_table_setup(self):
732-
from sqlalchemy import Table, Column
738+
from sqlalchemy import Table, Column, PrimaryKeyConstraint
733739

734740
column_names_and_types = \
735741
self._get_column_names_and_types(self._sqlalchemy_type)
736742

737743
columns = [Column(name, typ, index=is_index)
738744
for name, typ, is_index in column_names_and_types]
739745

740-
return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema)
746+
if self.keys is not None:
747+
columns.append(PrimaryKeyConstraint(self.keys,
748+
name=self.name+'_pk'))
749+
750+
751+
schema = self.schema or self.pd_sql.meta.schema
752+
753+
# At this point, attach to new metadata, only attach to self.meta
754+
# once table is created.
755+
from sqlalchemy.schema import MetaData
756+
meta = MetaData(self.pd_sql, schema=schema)
757+
758+
return Table(self.name, meta, *columns, schema=schema)
741759

742760
def _harmonize_columns(self, parse_dates=None):
743761
""" Make a data_frame's column type align with an sql_table
@@ -872,7 +890,6 @@ def execute(self, *args, **kwargs):
872890

873891
def read_table(self, table_name, index_col=None, coerce_float=True,
874892
parse_dates=None, columns=None, schema=None):
875-
876893
table = PandasSQLTable(
877894
table_name, self, index=index_col, schema=schema)
878895
return table.read(coerce_float=coerce_float,
@@ -901,6 +918,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
901918
table = PandasSQLTable(
902919
name, self, frame=frame, index=index, if_exists=if_exists,
903920
index_label=index_label, schema=schema)
921+
table.create()
904922
table.insert(chunksize)
905923
# check for potentially case sensitivity issues (GH7815)
906924
if name not in self.engine.table_names(schema=schema or self.meta.schema):
@@ -930,8 +948,9 @@ def drop_table(self, table_name, schema=None):
930948
self.get_table(table_name, schema).drop()
931949
self.meta.clear()
932950

933-
def _create_sql_schema(self, frame, table_name):
934-
table = PandasSQLTable(table_name, self, frame=frame)
951+
def _create_sql_schema(self, frame, table_name, keys=None):
952+
table = PandasSQLTable(table_name, self, frame=frame, index=False,
953+
keys=keys)
935954
return str(table.sql_schema())
936955

937956

@@ -997,8 +1016,8 @@ class PandasSQLTableLegacy(PandasSQLTable):
9971016
def sql_schema(self):
9981017
return str(";\n".join(self.table))
9991018

1000-
def create(self):
1001-
with self.pd_sql.con:
1019+
def _execute_create(self):
1020+
with self.get_session():
10021021
for stmt in self.table:
10031022
self.pd_sql.execute(stmt)
10041023

@@ -1019,28 +1038,12 @@ def insert_statement(self):
10191038
self.name, col_names, wildcards)
10201039
return insert_statement
10211040

1022-
def insert(self, chunksize=None):
1023-
1024-
ins = self.insert_statement()
1025-
keys, data_list = self.insert_data()
1026-
1027-
nrows = len(self.frame)
1028-
if chunksize is None:
1029-
chunksize = nrows
1030-
chunks = int(nrows / chunksize) + 1
1041+
def get_session(self):
1042+
return self.pd_sql.con
10311043

1032-
with self.pd_sql.con:
1033-
for i in range(chunks):
1034-
start_i = i * chunksize
1035-
end_i = min((i + 1) * chunksize, nrows)
1036-
if start_i >= end_i:
1037-
break
1038-
chunk_list = [arr[start_i:end_i] for arr in data_list]
1039-
insert_list = [tuple((v for v in row))
1040-
for row in zip(*chunk_list)]
1041-
cur = self.pd_sql.con.cursor()
1042-
cur.executemany(ins, insert_list)
1043-
cur.close()
1044+
def _execute_insert(self, trans, keys, data_iter):
1045+
data_list = list(data_iter)
1046+
trans.executemany(self.insert_statement(), data_list)
10441047

10451048
def _create_table_setup(self):
10461049
"""Return a list of SQL statement that create a table reflecting the
@@ -1061,21 +1064,25 @@ def _create_table_setup(self):
10611064
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
10621065
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char
10631066

1064-
col_template = br_l + '%s' + br_r + ' %s'
1065-
1066-
columns = ',\n '.join(col_template % (cname, ctype)
1067-
for cname, ctype, _ in column_names_and_types)
1068-
template = """CREATE TABLE %(name)s (
1069-
%(columns)s
1070-
)"""
1071-
create_stmts = [template % {'name': self.name, 'columns': columns}, ]
1072-
1073-
ix_tpl = "CREATE INDEX ix_{tbl}_{col} ON {tbl} ({br_l}{col}{br_r})"
1074-
for cname, _, is_index in column_names_and_types:
1075-
if not is_index:
1076-
continue
1077-
create_stmts.append(ix_tpl.format(tbl=self.name, col=cname,
1078-
br_l=br_l, br_r=br_r))
1067+
create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, ctype)
1068+
for cname, ctype, _ in column_names_and_types]
1069+
if self.keys is not None and len(self.keys):
1070+
cnames_br = ",".join([br_l + c + br_r for c in self.keys])
1071+
create_tbl_stmts.append(
1072+
"CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format(
1073+
tbl=self.name, cnames_br=cnames_br))
1074+
1075+
create_stmts = ["CREATE TABLE " + self.name + " (\n" +
1076+
',\n '.join(create_tbl_stmts) + "\n)"]
1077+
1078+
ix_cols = [cname for cname, _, is_index in column_names_and_types
1079+
if is_index]
1080+
if len(ix_cols):
1081+
cnames = "_".join(ix_cols)
1082+
cnames_br = ",".join([br_l + c + br_r for c in ix_cols])
1083+
create_stmts.append(
1084+
"CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format(
1085+
tbl=self.name, cnames=cnames, cnames_br=cnames_br))
10791086

10801087
return create_stmts
10811088

@@ -1172,16 +1179,28 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11721179
----------
11731180
frame: DataFrame
11741181
name: name of SQL table
1175-
flavor: {'sqlite', 'mysql'}, default 'sqlite'
11761182
if_exists: {'fail', 'replace', 'append'}, default 'fail'
11771183
fail: If table exists, do nothing.
11781184
replace: If table exists, drop it, recreate it, and insert data.
11791185
append: If table exists, insert data. Create if does not exist.
1186+
index : boolean, default True
1187+
Write DataFrame index as a column
1188+
index_label : string or sequence, default None
1189+
Column label for index column(s). If None is given (default) and
1190+
`index` is True, then the index names are used.
1191+
A sequence should be given if the DataFrame uses MultiIndex.
1192+
schema : string, default None
1193+
Ignored parameter included for compatability with SQLAlchemy version
1194+
of `to_sql`.
1195+
chunksize : int, default None
1196+
If not None, then rows will be written in batches of this size at a
1197+
time. If None, all rows will be written at once.
11801198
11811199
"""
11821200
table = PandasSQLTableLegacy(
11831201
name, self, frame=frame, index=index, if_exists=if_exists,
11841202
index_label=index_label)
1203+
table.create()
11851204
table.insert(chunksize)
11861205

11871206
def has_table(self, name, schema=None):
@@ -1200,8 +1219,9 @@ def drop_table(self, name, schema=None):
12001219
drop_sql = "DROP TABLE %s" % name
12011220
self.execute(drop_sql)
12021221

1203-
def _create_sql_schema(self, frame, table_name):
1204-
table = PandasSQLTableLegacy(table_name, self, frame=frame)
1222+
def _create_sql_schema(self, frame, table_name, keys=None):
1223+
table = PandasSQLTableLegacy(table_name, self, frame=frame, index=False,
1224+
keys=keys)
12051225
return str(table.sql_schema())
12061226

12071227

@@ -1227,58 +1247,8 @@ def get_schema(frame, name, flavor='sqlite', keys=None, con=None):
12271247
12281248
"""
12291249

1230-
if con is None:
1231-
if flavor == 'mysql':
1232-
warnings.warn(_MYSQL_WARNING, FutureWarning)
1233-
return _get_schema_legacy(frame, name, flavor, keys)
1234-
12351250
pandas_sql = pandasSQL_builder(con=con, flavor=flavor)
1236-
return pandas_sql._create_sql_schema(frame, name)
1237-
1238-
1239-
def _get_schema_legacy(frame, name, flavor, keys=None):
1240-
"""Old function from 0.13.1. To keep backwards compatibility.
1241-
When mysql legacy support is dropped, it should be possible to
1242-
remove this code
1243-
"""
1244-
1245-
def get_sqltype(dtype, flavor):
1246-
pytype = dtype.type
1247-
pytype_name = "text"
1248-
if issubclass(pytype, np.floating):
1249-
pytype_name = "float"
1250-
elif issubclass(pytype, np.integer):
1251-
pytype_name = "int"
1252-
elif issubclass(pytype, np.datetime64) or pytype is datetime:
1253-
# Caution: np.datetime64 is also a subclass of np.number.
1254-
pytype_name = "datetime"
1255-
elif pytype is datetime.date:
1256-
pytype_name = "date"
1257-
elif issubclass(pytype, np.bool_):
1258-
pytype_name = "bool"
1259-
1260-
return _SQL_TYPES[pytype_name][flavor]
1261-
1262-
lookup_type = lambda dtype: get_sqltype(dtype, flavor)
1263-
1264-
column_types = lzip(frame.dtypes.index, map(lookup_type, frame.dtypes))
1265-
if flavor == 'sqlite':
1266-
columns = ',\n '.join('[%s] %s' % x for x in column_types)
1267-
else:
1268-
columns = ',\n '.join('`%s` %s' % x for x in column_types)
1269-
1270-
keystr = ''
1271-
if keys is not None:
1272-
if isinstance(keys, string_types):
1273-
keys = (keys,)
1274-
keystr = ', PRIMARY KEY (%s)' % ','.join(keys)
1275-
template = """CREATE TABLE %(name)s (
1276-
%(columns)s
1277-
%(keystr)s
1278-
);"""
1279-
create_statement = template % {'name': name, 'columns': columns,
1280-
'keystr': keystr}
1281-
return create_statement
1251+
return pandas_sql._create_sql_schema(frame, name, keys=keys)
12821252

12831253

12841254
# legacy names, with depreciation warnings and copied docs

pandas/io/tests/test_sql.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,15 @@ def _get_index_columns(self, tbl_name):
14491449
def test_to_sql_save_index(self):
14501450
self._to_sql_save_index()
14511451

1452+
for ix_name, ix_col in zip(ixs.Key_name, ixs.Column_name):
1453+
if ix_name not in ix_cols:
1454+
ix_cols[ix_name] = []
1455+
ix_cols[ix_name].append(ix_col)
1456+
return ix_cols.values()
1457+
1458+
def test_to_sql_save_index(self):
1459+
self._to_sql_save_index()
1460+
14521461

14531462
#------------------------------------------------------------------------------
14541463
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
@@ -1545,7 +1554,7 @@ def test_schema(self):
15451554
frame = tm.makeTimeDataFrame()
15461555
create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],)
15471556
lines = create_sql.splitlines()
1548-
self.assertTrue('PRIMARY KEY (A,B)' in create_sql)
1557+
self.assertTrue('PRIMARY KEY ([A],[B])' in create_sql)
15491558
cur = self.db.cursor()
15501559
cur.execute(create_sql)
15511560

@@ -1824,7 +1833,7 @@ def test_schema(self):
18241833
drop_sql = "DROP TABLE IF EXISTS test"
18251834
create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],)
18261835
lines = create_sql.splitlines()
1827-
self.assertTrue('PRIMARY KEY (A,B)' in create_sql)
1836+
self.assertTrue('PRIMARY KEY (`A`,`B`)' in create_sql)
18281837
cur = self.db.cursor()
18291838
cur.execute(drop_sql)
18301839
cur.execute(create_sql)

0 commit comments

Comments
 (0)