Skip to content

Commit dd670e1

Browse files
Merge pull request #8926 from tiagoantao/master
ENH: dtype costumization on to_sql (GH8778)
2 parents 2063c1f + 5c9058b commit dd670e1

File tree

5 files changed

+97
-13
lines changed

5 files changed

+97
-13
lines changed

doc/source/io.rst

+8
Original file line numberDiff line numberDiff line change
@@ -3413,6 +3413,14 @@ With some databases, writing large DataFrames can result in errors due to packet
34133413
Because of this, reading the database table back in does **not** generate
34143414
a categorical.
34153415

3416+
.. note::
3417+
3418+
You can specify the SQL type of any of the columns by using the dtypes
3419+
parameter (a dictionary mapping column names to SQLAlchemy types). This
3420+
can be useful in cases where columns with NULL values are inferred by
3421+
Pandas to an excessively general datatype (e.g. a boolean column is is
3422+
inferred to be object because it has NULLs).
3423+
34163424

34173425
Reading Tables
34183426
~~~~~~~~~~~~~~

doc/source/whatsnew/v0.15.2.txt

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ API changes
6464
Enhancements
6565
~~~~~~~~~~~~
6666

67+
- Added the ability to specify the SQL type of columns when writing a DataFrame to a database (:issue:`8778`).
6768
- Added ability to export Categorical data to Stata (:issue:`8633`). See :ref:`here <io.stata-categorical>` for limitations of categorical variables exported to Stata data files.
6869
- Added ability to export Categorical data to to/from HDF5 (:issue:`7621`). Queries work the same as if it was an object array. However, the ``category`` dtyped data is stored in a more efficient manner. See :ref:`here <io.hdf5-categorical>` for an example and caveats w.r.t. prior versions of pandas.
6970
- Added support for ``utcfromtimestamp()``, ``fromtimestamp()``, and ``combine()`` on `Timestamp` class (:issue:`5351`).

pandas/core/generic.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ def to_msgpack(self, path_or_buf=None, **kwargs):
922922
return packers.to_msgpack(path_or_buf, self, **kwargs)
923923

924924
def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
925-
index=True, index_label=None, chunksize=None):
925+
index=True, index_label=None, chunksize=None, dtype=None):
926926
"""
927927
Write records stored in a DataFrame to a SQL database.
928928
@@ -954,12 +954,15 @@ def to_sql(self, name, con, flavor='sqlite', schema=None, if_exists='fail',
954954
chunksize : int, default None
955955
If not None, then rows will be written in batches of this size at a
956956
time. If None, all rows will be written at once.
957+
dtype : Dictionary of column name to SQLAlchemy type, default None
958+
Optional datatypes for SQL columns.
957959
958960
"""
959961
from pandas.io import sql
960962
sql.to_sql(
961963
self, name, con, flavor=flavor, schema=schema, if_exists=if_exists,
962-
index=index, index_label=index_label, chunksize=chunksize)
964+
index=index, index_label=index_label, chunksize=chunksize,
965+
dtype=dtype)
963966

964967
def to_pickle(self, path):
965968
"""

pandas/io/sql.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,
484484

485485

486486
def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
487-
index=True, index_label=None, chunksize=None):
487+
index=True, index_label=None, chunksize=None, dtype=None):
488488
"""
489489
Write records stored in a DataFrame to a SQL database.
490490
@@ -517,6 +517,8 @@ def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
517517
chunksize : int, default None
518518
If not None, then rows will be written in batches of this size at a
519519
time. If None, all rows will be written at once.
520+
dtype : dictionary of column name to SQLAchemy type, default None
521+
optional datatypes for SQL columns.
520522
521523
"""
522524
if if_exists not in ('fail', 'replace', 'append'):
@@ -531,7 +533,7 @@ def to_sql(frame, name, con, flavor='sqlite', schema=None, if_exists='fail',
531533

532534
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
533535
index_label=index_label, schema=schema,
534-
chunksize=chunksize)
536+
chunksize=chunksize, dtype=dtype)
535537

536538

537539
def has_table(table_name, con, flavor='sqlite', schema=None):
@@ -596,7 +598,7 @@ class SQLTable(PandasObject):
596598
# TODO: support for multiIndex
597599
def __init__(self, name, pandas_sql_engine, frame=None, index=True,
598600
if_exists='fail', prefix='pandas', index_label=None,
599-
schema=None, keys=None):
601+
schema=None, keys=None, dtype=None):
600602
self.name = name
601603
self.pd_sql = pandas_sql_engine
602604
self.prefix = prefix
@@ -605,6 +607,7 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True,
605607
self.schema = schema
606608
self.if_exists = if_exists
607609
self.keys = keys
610+
self.dtype = dtype
608611

609612
if frame is not None:
610613
# We want to initialize based on a dataframe
@@ -885,6 +888,10 @@ def _sqlalchemy_type(self, col):
885888
from sqlalchemy.types import (BigInteger, Float, Text, Boolean,
886889
DateTime, Date, Time)
887890

891+
dtype = self.dtype or {}
892+
if col.name in dtype:
893+
return self.dtype[col.name]
894+
888895
if com.is_datetime64_dtype(col):
889896
try:
890897
tz = col.tzinfo
@@ -1099,7 +1106,7 @@ def read_query(self, sql, index_col=None, coerce_float=True,
10991106
read_sql = read_query
11001107

11011108
def to_sql(self, frame, name, if_exists='fail', index=True,
1102-
index_label=None, schema=None, chunksize=None):
1109+
index_label=None, schema=None, chunksize=None, dtype=None):
11031110
"""
11041111
Write records stored in a DataFrame to a SQL database.
11051112
@@ -1125,11 +1132,20 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
11251132
chunksize : int, default None
11261133
If not None, then rows will be written in batches of this size at a
11271134
time. If None, all rows will be written at once.
1128-
1135+
dtype : dictionary of column name to SQLAlchemy type, default None
1136+
Optional datatypes for SQL columns.
1137+
11291138
"""
1139+
if dtype is not None:
1140+
import sqlalchemy.sql.type_api as type_api
1141+
for col, my_type in dtype.items():
1142+
if not issubclass(my_type, type_api.TypeEngine):
1143+
raise ValueError('The type of %s is not a SQLAlchemy '
1144+
'type ' % col)
1145+
11301146
table = SQLTable(name, self, frame=frame, index=index,
11311147
if_exists=if_exists, index_label=index_label,
1132-
schema=schema)
1148+
schema=schema, dtype=dtype)
11331149
table.create()
11341150
table.insert(chunksize)
11351151
# check for potentially case sensitivity issues (GH7815)
@@ -1297,6 +1313,9 @@ def _create_table_setup(self):
12971313
return create_stmts
12981314

12991315
def _sql_type_name(self, col):
1316+
dtype = self.dtype or {}
1317+
if col.name in dtype:
1318+
return dtype[col.name]
13001319
pytype = col.dtype.type
13011320
pytype_name = "text"
13021321
if issubclass(pytype, np.floating):
@@ -1424,7 +1443,7 @@ def _fetchall_as_list(self, cur):
14241443
return result
14251444

14261445
def to_sql(self, frame, name, if_exists='fail', index=True,
1427-
index_label=None, schema=None, chunksize=None):
1446+
index_label=None, schema=None, chunksize=None, dtype=None):
14281447
"""
14291448
Write records stored in a DataFrame to a SQL database.
14301449
@@ -1448,10 +1467,19 @@ def to_sql(self, frame, name, if_exists='fail', index=True,
14481467
chunksize : int, default None
14491468
If not None, then rows will be written in batches of this
14501469
size at a time. If None, all rows will be written at once.
1470+
dtype : dictionary of column_name to SQLite string type, default None
1471+
optional datatypes for SQL columns.
14511472
14521473
"""
1474+
if dtype is not None:
1475+
for col, my_type in dtype.items():
1476+
if not isinstance(my_type, str):
1477+
raise ValueError('%s (%s) not a string' % (
1478+
col, str(my_type)))
1479+
14531480
table = SQLiteTable(name, self, frame=frame, index=index,
1454-
if_exists=if_exists, index_label=index_label)
1481+
if_exists=if_exists, index_label=index_label,
1482+
dtype=dtype)
14551483
table.create()
14561484
table.insert(chunksize)
14571485

pandas/io/tests/test_sql.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
try:
4343
import sqlalchemy
44+
import sqlalchemy.schema
45+
import sqlalchemy.sql.sqltypes as sqltypes
4446
SQLALCHEMY_INSTALLED = True
4547
except ImportError:
4648
SQLALCHEMY_INSTALLED = False
@@ -339,7 +341,7 @@ def _transaction_test(self):
339341
self.pandasSQL.execute("CREATE TABLE test_trans (A INT, B TEXT)")
340342

341343
ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')"
342-
344+
343345
# Make sure when transaction is rolled back, no rows get inserted
344346
try:
345347
with self.pandasSQL.run_transaction() as trans:
@@ -350,7 +352,7 @@ def _transaction_test(self):
350352
pass
351353
res = self.pandasSQL.read_query('SELECT * FROM test_trans')
352354
self.assertEqual(len(res), 0)
353-
355+
354356
# Make sure when transaction is committed, rows do get inserted
355357
with self.pandasSQL.run_transaction() as trans:
356358
trans.execute(ins_sql)
@@ -1167,6 +1169,26 @@ def test_get_schema_create_table(self):
11671169
tm.assert_frame_equal(returned_df, blank_test_df)
11681170
self.drop_table(tbl)
11691171

1172+
def test_dtype(self):
1173+
cols = ['A', 'B']
1174+
data = [(0.8, True),
1175+
(0.9, None)]
1176+
df = DataFrame(data, columns=cols)
1177+
df.to_sql('dtype_test', self.conn)
1178+
df.to_sql('dtype_test2', self.conn, dtype={'B': sqlalchemy.Boolean})
1179+
meta = sqlalchemy.schema.MetaData(bind=self.conn)
1180+
meta.reflect()
1181+
self.assertTrue(isinstance(meta.tables['dtype_test'].columns['B'].type,
1182+
sqltypes.TEXT))
1183+
if self.flavor == 'mysql':
1184+
my_type = sqltypes.Integer
1185+
else:
1186+
my_type = sqltypes.Boolean
1187+
self.assertTrue(isinstance(meta.tables['dtype_test2'].columns['B'].type,
1188+
my_type))
1189+
self.assertRaises(ValueError, df.to_sql,
1190+
'error', self.conn, dtype={'B': bool})
1191+
11701192

11711193
class TestSQLiteAlchemy(_TestSQLAlchemy):
11721194
"""
@@ -1467,7 +1489,7 @@ def test_datetime_time(self):
14671489
if self.flavor == 'sqlite':
14681490
self.assertRaises(sqlite3.InterfaceError, sql.to_sql, df,
14691491
'test_time', self.conn)
1470-
1492+
14711493
def _get_index_columns(self, tbl_name):
14721494
ixs = sql.read_sql_query(
14731495
"SELECT * FROM sqlite_master WHERE type = 'index' " +
@@ -1485,6 +1507,28 @@ def test_to_sql_save_index(self):
14851507
def test_transactions(self):
14861508
self._transaction_test()
14871509

1510+
def test_dtype(self):
1511+
if self.flavor == 'mysql':
1512+
raise nose.SkipTest('Not applicable to MySQL legacy')
1513+
cols = ['A', 'B']
1514+
data = [(0.8, True),
1515+
(0.9, None)]
1516+
df = DataFrame(data, columns=cols)
1517+
df.to_sql('dtype_test', self.conn)
1518+
df.to_sql('dtype_test2', self.conn, dtype={'B': 'bool'})
1519+
1520+
def get_column_type(table, column):
1521+
recs = self.conn.execute('PRAGMA table_info(%s)' % table)
1522+
for cid, name, ctype, not_null, default, pk in recs:
1523+
if name == column:
1524+
return ctype
1525+
raise ValueError('Table %s, column %s not found' % (table, column))
1526+
1527+
self.assertEqual(get_column_type('dtype_test', 'B'), 'TEXT')
1528+
self.assertEqual(get_column_type('dtype_test2', 'B'), 'bool')
1529+
self.assertRaises(ValueError, df.to_sql,
1530+
'error', self.conn, dtype={'B': bool})
1531+
14881532
class TestMySQLLegacy(TestSQLiteFallback):
14891533
"""
14901534
Test the legacy mode against a MySQL database.

0 commit comments

Comments
 (0)