Skip to content

Commit a77cdfd

Browse files
committed
ENH: to_sql() add parameter "method" to control insertions method (pandas-dev#8953)
1 parent abfac97 commit a77cdfd

File tree

4 files changed

+131
-20
lines changed

4 files changed

+131
-20
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Other Enhancements
1616
- :func:`Series.mode` and :func:`DataFrame.mode` now support the ``dropna`` parameter which can be used to specify whether NaN/NaT values should be considered (:issue:`17534`)
1717
- :func:`to_csv` now supports ``compression`` keyword when a file handle is passed. (:issue:`21227`)
1818
- :meth:`Index.droplevel` is now implemented also for flat indexes, for compatibility with MultiIndex (:issue:`21115`)
19+
- :func:`~pandas.DataFrame.to_sql` now supports a ``method`` parameter to control the SQL insertion clause (:issue:`8953`)
1920

2021

2122
.. _whatsnew_0240.api_breaking:

pandas/core/generic.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -2014,7 +2014,7 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
20142014
**kwargs)
20152015

20162016
def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
2017-
index_label=None, chunksize=None, dtype=None):
2017+
index_label=None, chunksize=None, dtype=None, method='default'):
20182018
"""
20192019
Write records stored in a DataFrame to a SQL database.
20202020
@@ -2052,6 +2052,8 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
20522052
Specifying the datatype for columns. The keys should be the column
20532053
names and the values should be the SQLAlchemy types or strings for
20542054
the sqlite3 legacy mode.
2055+
method : {'default', 'multi', callable}, default 'default'
2056+
Controls the SQL insertion clause used.
20552057
20562058
Raises
20572059
------
@@ -2120,11 +2122,59 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True,
21202122
21212123
>>> engine.execute("SELECT * FROM integers").fetchall()
21222124
[(1,), (None,), (2,)]
2125+
2126+
Insertion method:
2127+
2128+
.. versionadded:: 0.24.0
2129+
2130+
The parameter ``method`` controls the SQL insertion clause used.
2131+
Possible values are:
2132+
2133+
- `'default'`: Uses standard SQL `INSERT` clause
2134+
- `'multi'`: Pass multiple values in a single `INSERT` clause.
2135+
It uses a **special** SQL syntax not supported by all backends.
2136+
This usually provides a big performance gain for analytic databases
2137+
like *Presto* and *Redshift*, but has worse performance for
2138+
traditional SQL backends if the table contains many columns.
2139+
For more information check the SQLAlchemy `documentation <http://docs.sqlalchemy.org/en/latest/core/dml.html?highlight=multivalues#sqlalchemy.sql.expression.Insert.values.params.*args>`__.
2140+
- callable: with signature `(pd_table, conn, keys, data_iter)`.
2141+
This can be used to implement more performant insertion based on
2142+
specific backend dialect features.
2143+
E.g. using *PostgreSQL* `COPY clause
2144+
<https://www.postgresql.org/docs/current/static/sql-copy.html>`__.
2145+
Check API for details and a sample implementation
2146+
:func:`~pandas.DataFrame.to_sql`.
2147+
2148+
2149+
Example of callable for Postgresql *COPY*::
2150+
2151+
# Alternative to_sql() *method* for DBs that support COPY FROM
2152+
import csv
2153+
from io import StringIO
2154+
2155+
def psql_insert_copy(table, conn, keys, data_iter):
2156+
# gets a DBAPI connection that can provide a cursor
2157+
dbapi_conn = conn.connection
2158+
with dbapi_conn.cursor() as cur:
2159+
s_buf = StringIO()
2160+
writer = csv.writer(s_buf)
2161+
writer.writerows(data_iter)
2162+
s_buf.seek(0)
2163+
2164+
columns = ', '.join('"{}"'.format(k) for k in keys)
2165+
if table.schema:
2166+
table_name = '{}.{}'.format(table.schema, table.name)
2167+
else:
2168+
table_name = table.name
2169+
2170+
sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format(
2171+
table_name, columns)
2172+
cur.copy_expert(sql=sql, file=s_buf)
21232173
"""
21242174
from pandas.io import sql
21252175
sql.to_sql(self, name, con, schema=schema, if_exists=if_exists,
21262176
index=index, index_label=index_label, chunksize=chunksize,
2127-
dtype=dtype)
2177+
dtype=dtype, method=method)
21282178

21292179
def to_pickle(self, path, compression='infer',
21302180
protocol=pkl.HIGHEST_PROTOCOL):

pandas/io/sql.py

+46-15
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import print_function, division
88
from datetime import datetime, date, time
9+
from functools import partial
910

1011
import warnings
1112
import re
@@ -398,7 +399,7 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
398399

399400

400401
def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
401-
index_label=None, chunksize=None, dtype=None):
402+
index_label=None, chunksize=None, dtype=None, method='default'):
402403
"""
403404
Write records stored in a DataFrame to a SQL database.
404405
@@ -432,6 +433,8 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
432433
Optional specifying the datatype for columns. The SQL type should
433434
be a SQLAlchemy type, or a string for sqlite3 fallback connection.
434435
If all columns are of the same type, one single value can be used.
436+
method : {'default', 'multi', callable}, default 'default'
437+
Controls the SQL insertion clause used.
435438
436439
"""
437440
if if_exists not in ('fail', 'replace', 'append'):
@@ -447,7 +450,7 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True,
447450

448451
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
449452
index_label=index_label, schema=schema,
450-
chunksize=chunksize, dtype=dtype)
453+
chunksize=chunksize, dtype=dtype, method=method)
451454

452455

453456
def has_table(table_name, con, schema=None):
@@ -572,8 +575,25 @@ def create(self):
572575
else:
573576
self._execute_create()
574577

575-
def insert_statement(self):
576-
return self.table.insert()
578+
def _execute_insert(self, conn, keys, data_iter):
579+
"""Execute SQL statement inserting data
580+
581+
Parameters
582+
----------
583+
data_iter : iterable of tuples
584+
of values to be inserted
585+
"""
586+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
587+
conn.execute(self.table.insert(), data)
588+
589+
def _execute_insert_multi(self, conn, keys, data_iter):
590+
"""Alternative to _execute_insert for DBs that support multivalue INSERT.
591+
592+
Note: multi-value insert is usually faster for a few columns
593+
but performance degrades quickly with increase of columns.
594+
"""
595+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
596+
conn.execute(self.table.insert(data))
577597

578598
def insert_data(self):
579599
if self.index is not None:
@@ -611,11 +631,18 @@ def insert_data(self):
611631

612632
return column_names, data_list
613633

614-
def _execute_insert(self, conn, keys, data_iter):
615-
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
616-
conn.execute(self.insert_statement(), data)
634+
def insert(self, chunksize=None, method=None):
635+
636+
# set insert method
637+
if method in (None, 'default'):
638+
exec_insert = self._execute_insert
639+
elif method == 'multi':
640+
exec_insert = self._execute_insert_multi
641+
elif callable(method):
642+
exec_insert = partial(method, self)
643+
else:
644+
raise ValueError('Invalid parameter `method`: {}'.format(method))
617645

618-
def insert(self, chunksize=None):
619646
keys, data_list = self.insert_data()
620647

621648
nrows = len(self.frame)
@@ -638,7 +665,7 @@ def insert(self, chunksize=None):
638665
break
639666

640667
chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list])
641-
self._execute_insert(conn, keys, chunk_iter)
668+
exec_insert(conn, keys, chunk_iter)
642669

643670
def _query_iterator(self, result, chunksize, columns, coerce_float=True,
644671
parse_dates=None):
@@ -1078,7 +1105,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
10781105
read_sql = read_query
10791106

10801107
def to_sql(self, frame, name, if_exists='fail', index=True,
1081-
index_label=None, schema=None, chunksize=None, dtype=None):
1108+
index_label=None, schema=None, chunksize=None, dtype=None,
1109+
method='default'):
10821110
"""
10831111
Write records stored in a DataFrame to a SQL database.
10841112
@@ -1108,7 +1136,8 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11081136
Optional specifying the datatype for columns. The SQL type should
11091137
be a SQLAlchemy type. If all columns are of the same type, one
11101138
single value can be used.
1111-
1139+
method : {'default', 'multi', callable}, default 'default'
1140+
Controls the SQL insertion clause used.
11121141
"""
11131142
if dtype and not is_dict_like(dtype):
11141143
dtype = {col_name: dtype for col_name in frame}
@@ -1124,7 +1153,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11241153
if_exists=if_exists, index_label=index_label,
11251154
schema=schema, dtype=dtype)
11261155
table.create()
1127-
table.insert(chunksize)
1156+
table.insert(chunksize, method=method)
11281157
if (not name.isdigit() and not name.islower()):
11291158
# check for potentially case sensitivity issues (GH7815)
11301159
# Only check when name is not a number and name is not lower case
@@ -1434,7 +1463,8 @@ def _fetchall_as_list(self, cur):
14341463
return result
14351464

14361465
def to_sql(self, frame, name, if_exists='fail', index=True,
1437-
index_label=None, schema=None, chunksize=None, dtype=None):
1466+
index_label=None, schema=None, chunksize=None, dtype=None,
1467+
method='default'):
14381468
"""
14391469
Write records stored in a DataFrame to a SQL database.
14401470
@@ -1463,7 +1493,8 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
14631493
Optional specifying the datatype for columns. The SQL type should
14641494
be a string. If all columns are of the same type, one single value
14651495
can be used.
1466-
1496+
method : {'default', 'multi', callable}, default 'default'
1497+
Controls the SQL insertion clause used.
14671498
"""
14681499
if dtype and not is_dict_like(dtype):
14691500
dtype = {col_name: dtype for col_name in frame}
@@ -1478,7 +1509,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
14781509
if_exists=if_exists, index_label=index_label,
14791510
dtype=dtype)
14801511
table.create()
1481-
table.insert(chunksize)
1512+
table.insert(chunksize, method)
14821513

14831514
def has_table(self, name, schema=None):
14841515
# TODO(wesm): unused?

pandas/tests/io/test_sql.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,16 @@ def _read_sql_iris_named_parameter(self):
372372
iris_frame = self.pandasSQL.read_query(query, params=params)
373373
self._check_iris_loaded_frame(iris_frame)
374374

375-
def _to_sql(self):
375+
def _to_sql(self, method=None):
376376
self.drop_table('test_frame1')
377377

378-
self.pandasSQL.to_sql(self.test_frame1, 'test_frame1')
378+
self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=method)
379379
assert self.pandasSQL.has_table('test_frame1')
380380

381+
num_entries = len(self.test_frame1)
382+
num_rows = self._count_rows('test_frame1')
383+
assert num_rows == num_entries
384+
381385
# Nuke table
382386
self.drop_table('test_frame1')
383387

@@ -431,6 +435,25 @@ def _to_sql_append(self):
431435
assert num_rows == num_entries
432436
self.drop_table('test_frame1')
433437

438+
def _to_sql_method_callable(self):
439+
check = [] # used to double check function below is really being used
440+
441+
def sample(pd_table, conn, keys, data_iter):
442+
check.append(1)
443+
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
444+
conn.execute(pd_table.table.insert(), data)
445+
self.drop_table('test_frame1')
446+
447+
self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=sample)
448+
assert self.pandasSQL.has_table('test_frame1')
449+
450+
assert check == [1]
451+
num_entries = len(self.test_frame1)
452+
num_rows = self._count_rows('test_frame1')
453+
assert num_rows == num_entries
454+
# Nuke table
455+
self.drop_table('test_frame1')
456+
434457
def _roundtrip(self):
435458
self.drop_table('test_frame_roundtrip')
436459
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
@@ -1180,7 +1203,7 @@ def setup_connect(self):
11801203
pytest.skip(
11811204
"Can't connect to {0} server".format(self.flavor))
11821205

1183-
def test_aread_sql(self):
1206+
def test_read_sql(self):
11841207
self._read_sql_iris()
11851208

11861209
def test_read_sql_parameter(self):
@@ -1204,6 +1227,12 @@ def test_to_sql_replace(self):
12041227
def test_to_sql_append(self):
12051228
self._to_sql_append()
12061229

1230+
def test_to_sql_method_multi(self):
1231+
self._to_sql(method='multi')
1232+
1233+
def test_to_sql_method_callable(self):
1234+
self._to_sql_method_callable()
1235+
12071236
def test_create_table(self):
12081237
temp_conn = self.connect()
12091238
temp_frame = DataFrame(

0 commit comments

Comments
 (0)