Skip to content

Commit fce019b

Browse files
committed
ENH: 'to_sql()' add param 'method' to control insert statement (pandas-dev#21103)
Also revert default insert method to NOT use multi-value.
1 parent 3147a86 commit fce019b

File tree

2 files changed

+61
-30
lines changed

2 files changed

+61
-30
lines changed

pandas/core/generic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2014,7 +2014,7 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
20142014
**kwargs)
20152015

20162016
def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
2017-
index_label=None, chunksize=None, dtype=None):
2017+
index_label=None, chunksize=None, dtype=None, method=None):
20182018
"""
20192019
Write records stored in a DataFrame to a SQL database.
20202020
@@ -2124,7 +2124,7 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
21242124
from pandas.io import sql
21252125
sql.to_sql(self, name, con, schema=schema, if_exists=if_exists,
21262126
index=index, index_label=index_label, chunksize=chunksize,
2127-
dtype=dtype)
2127+
dtype=dtype, method=method)
21282128

21292129
def to_pickle(self, path, compression='infer',
21302130
protocol=pkl.HIGHEST_PROTOCOL):

pandas/io/sql.py

+59-28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from __future__ import print_function, division
88
from datetime import datetime, date, time
9+
import csv
10+
from io import StringIO
911

1012
import warnings
1113
import re
@@ -398,7 +400,7 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
398400

399401

400402
def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
401-
index_label=None, chunksize=None, dtype=None):
403+
index_label=None, chunksize=None, dtype=None, method=None):
402404
"""
403405
Write records stored in a DataFrame to a SQL database.
404406
@@ -447,7 +449,7 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
447449

448450
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
449451
index_label=index_label, schema=schema,
450-
chunksize=chunksize, dtype=dtype)
452+
chunksize=chunksize, dtype=dtype, method=method)
451453

452454

453455
def has_table(table_name, con, schema=None):
@@ -572,29 +574,47 @@ def create(self):
572574
else:
573575
self._execute_create()
574576

575-
def insert_statement(self, data, conn):
576-
"""
577-
Generate tuple of SQLAlchemy insert statement and any arguments
578-
to be executed by connection (via `_execute_insert`).
577+
def _exec_insert(self, conn, keys, data_iter):
578+
"""Execute SQL statement inserting data
579579
580580
Parameters
581581
----------
582-
conn : SQLAlchemy connectable(engine/connection)
583-
Connection to receive the data
584-
data : list of dict
585-
The data to be inserted
582+
data_iter : iterable of rows
583+
Each row holds the values to be inserted, ordered to match `keys`.
584+
"""
585+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
586+
conn.execute(self.table.insert(), data)
586587

587-
Returns
588-
-------
589-
SQLAlchemy statement
590-
insert statement
591-
*, optional
592-
Additional parameters to be passed when executing insert statement
588+
def _exec_insert_multi(self, conn, keys, data_iter):
589+
"""Alternative to _exec_insert for DBs that support multivalue INSERT.
590+
591+
Note: a multi-value insert is usually faster for tables with few columns,
592+
but performance degrades quickly as the number of columns increases.
593593
"""
594-
dialect = getattr(conn, 'dialect', None)
595-
if dialect and getattr(dialect, 'supports_multivalues_insert', False):
596-
return self.table.insert(data),
597-
return self.table.insert(), data
594+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
595+
conn.execute(self.table.insert(data))
596+
597+
def _exec_insert_copy(self, conn, keys, data_iter):
598+
"""Alternative to _exec_insert for DBs that support COPY FROM
599+
"""
600+
# gets a DBAPI connection that can provide a cursor
601+
dbapi_conn = conn.connection
602+
with dbapi_conn.cursor() as cur:
603+
s_buf = StringIO()
604+
writer = csv.writer(s_buf)
605+
writer.writerows(data_iter)
606+
s_buf.seek(0)
607+
608+
columns = ', '.join('"{}"'.format(k) for k in keys)
609+
if self.schema:
610+
table_name = '{}.{}'.format(self.schema, self.name)
611+
else:
612+
table_name = self.name
613+
614+
sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format(
615+
table_name, columns)
616+
cur.copy_expert(sql=sql, file=s_buf)
617+
598618

599619
def insert_data(self):
600620
if self.index is not None:
@@ -632,12 +652,20 @@ def insert_data(self):
632652

633653
return column_names, data_list
634654

635-
def _execute_insert(self, conn, keys, data_iter):
636-
"""Insert data into this table with database connection"""
637-
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
638-
conn.execute(*self.insert_statement(data, conn))
639655

640-
def insert(self, chunksize=None):
656+
def insert(self, chunksize=None, method=None):
657+
658+
# set insert method
659+
if method in (None, 'default'):
660+
exec_insert = self._exec_insert
661+
elif method == 'multi':
662+
exec_insert = self._exec_insert_multi
663+
elif method == 'copy':
664+
exec_insert = self._exec_insert_copy
665+
else:
666+
# TODO: support callables?
667+
raise ValueError('Invalid parameter `method`: {}'.format(method))
668+
641669
keys, data_list = self.insert_data()
642670

643671
nrows = len(self.frame)
@@ -660,7 +688,9 @@ def insert(self, chunksize=None):
660688
break
661689

662690
chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
663-
self._execute_insert(conn, keys, chunk_iter)
691+
exec_insert(conn, keys, chunk_iter)
692+
693+
664694

665695
def _query_iterator(self, result, chunksize, columns, coerce_float=True,
666696
parse_dates=None):
@@ -1100,7 +1130,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
11001130
read_sql = read_query
11011131

11021132
def to_sql(self, frame, name, if_exists='fail', index=True,
1103-
index_label=None, schema=None, chunksize=None, dtype=None):
1133+
index_label=None, schema=None, chunksize=None, dtype=None,
1134+
method=None):
11041135
"""
11051136
Write records stored in a DataFrame to a SQL database.
11061137
@@ -1146,7 +1177,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11461177
if_exists=if_exists, index_label=index_label,
11471178
schema=schema, dtype=dtype)
11481179
table.create()
1149-
table.insert(chunksize)
1180+
table.insert(chunksize, method=method)
11501181
if (not name.isdigit() and not name.islower()):
11511182
# check for potentially case sensitivity issues (GH7815)
11521183
# Only check when name is not a number and name is not lower case

0 commit comments

Comments
 (0)