
Commit c995254

avinashpancham authored and luckyvs1 committed
ENH: Add dtype argument to read_sql_query (GH10285) (pandas-dev#37546)
1 parent 785ec4d commit c995254

File tree: 4 files changed, +57 -7 lines changed


doc/source/whatsnew/v1.3.0.rst

+1
@@ -43,6 +43,7 @@ Other enhancements
 - Added ``end`` and ``end_day`` options for ``origin`` in :meth:`DataFrame.resample` (:issue:`37804`)
 - Improve error message when ``usecols`` and ``names`` do not match for :func:`read_csv` and ``engine="c"`` (:issue:`29042`)
 - Improved consistency of error message when passing an invalid ``win_type`` argument in :class:`Window` (:issue:`15969`)
+- :func:`pandas.read_sql_query` now accepts a ``dtype`` argument to cast the columnar data from the SQL database based on user input (:issue:`10285`)

 .. ---------------------------------------------------------------------------
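A minimal usage sketch of the new argument (not part of the commit; it assumes a pandas build that includes this change, i.e. 1.3.0 or later, and the table name and sample data are made up):

import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame({"a": [1, 2], "b": [0.5, 1.5]}).to_sql("demo", conn, index=False)

# A single dtype is applied to every column ...
all_float = pd.read_sql_query("SELECT * FROM demo", conn, dtype="float64")

# ... while a dict casts column by column.
mixed = pd.read_sql_query(
    "SELECT * FROM demo", conn, dtype={"a": "Int64", "b": "float32"}
)
print(mixed.dtypes)  # a -> Int64, b -> float32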

pandas/_typing.py

+8 -7
@@ -71,13 +71,6 @@
 ]
 Timezone = Union[str, tzinfo]

-# other
-
-Dtype = Union[
-    "ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]]
-]
-DtypeObj = Union[np.dtype, "ExtensionDtype"]
-
 # FrameOrSeriesUnion means either a DataFrame or a Series. E.g.
 # `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series
 # is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed
@@ -100,6 +93,14 @@
 JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
 Axes = Collection

+# dtypes
+Dtype = Union[
+    "ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]]
+]
+# DtypeArg specifies all allowable dtypes in a functions its dtype argument
+DtypeArg = Union[Dtype, Dict[Label, Dtype]]
+DtypeObj = Union[np.dtype, "ExtensionDtype"]
+
 # For functions like rename that convert one label to another
 Renamer = Union[Mapping[Label, Any], Callable[[Label], Label]]
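The new DtypeArg alias is what the dtype parameter in pandas/io/sql.py is annotated with: either a single dtype applied to every column, or a mapping from column label to dtype. A hypothetical sketch (not part of the commit, assuming a pandas version that includes this change) of a helper annotated with it; note that pandas._typing is a private module, imported here only for illustration:

from typing import Optional

import pandas as pd
from pandas._typing import DtypeArg  # private module, illustration only


def cast_frame(frame: pd.DataFrame, dtype: Optional[DtypeArg] = None) -> pd.DataFrame:
    # Mirrors the guard added to _wrap_result below: cast only when a dtype is given.
    return frame.astype(dtype) if dtype else frame


cast_frame(pd.DataFrame({"a": [1, 2]}), dtype="float64")       # single dtype
cast_frame(pd.DataFrame({"a": [1, 2]}), dtype={"a": "Int64"})  # per-column dict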

pandas/io/sql.py

+27
@@ -13,6 +13,7 @@
 import numpy as np

 import pandas._libs.lib as lib
+from pandas._typing import DtypeArg

 from pandas.core.dtypes.common import is_datetime64tz_dtype, is_dict_like, is_list_like
 from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -132,10 +133,14 @@ def _wrap_result(
     index_col=None,
     coerce_float: bool = True,
     parse_dates=None,
+    dtype: Optional[DtypeArg] = None,
 ):
     """Wrap result set of query in a DataFrame."""
     frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float)

+    if dtype:
+        frame = frame.astype(dtype)
+
     frame = _parse_date_columns(frame, parse_dates)

     if index_col is not None:
@@ -308,6 +313,7 @@ def read_sql_query(
     params=None,
     parse_dates=None,
     chunksize: None = None,
+    dtype: Optional[DtypeArg] = None,
 ) -> DataFrame:
     ...

@@ -321,6 +327,7 @@ def read_sql_query(
     params=None,
     parse_dates=None,
     chunksize: int = 1,
+    dtype: Optional[DtypeArg] = None,
 ) -> Iterator[DataFrame]:
     ...

@@ -333,6 +340,7 @@ def read_sql_query(
     params=None,
     parse_dates=None,
     chunksize: Optional[int] = None,
+    dtype: Optional[DtypeArg] = None,
 ) -> Union[DataFrame, Iterator[DataFrame]]:
     """
     Read SQL query into a DataFrame.
@@ -371,6 +379,9 @@ def read_sql_query(
     chunksize : int, default None
         If specified, return an iterator where `chunksize` is the number of
         rows to include in each chunk.
+    dtype : Type name or dict of columns
+        Data type for data or columns. E.g. np.float64 or
+        {‘a’: np.float64, ‘b’: np.int32, ‘c’: ‘Int64’}

     Returns
     -------
@@ -394,6 +405,7 @@ def read_sql_query(
         coerce_float=coerce_float,
         parse_dates=parse_dates,
         chunksize=chunksize,
+        dtype=dtype,
     )


@@ -1307,6 +1319,7 @@ def _query_iterator(
         index_col=None,
         coerce_float=True,
         parse_dates=None,
+        dtype: Optional[DtypeArg] = None,
     ):
         """Return generator through chunked result set"""
         while True:
@@ -1320,6 +1333,7 @@ def _query_iterator(
                     index_col=index_col,
                     coerce_float=coerce_float,
                     parse_dates=parse_dates,
+                    dtype=dtype,
                 )

     def read_query(
@@ -1330,6 +1344,7 @@ def read_query(
         parse_dates=None,
         params=None,
         chunksize: Optional[int] = None,
+        dtype: Optional[DtypeArg] = None,
     ):
         """
         Read SQL query into a DataFrame.
@@ -1361,6 +1376,11 @@ def read_query(
         chunksize : int, default None
             If specified, return an iterator where `chunksize` is the number
             of rows to include in each chunk.
+        dtype : Type name or dict of columns
+            Data type for data or columns. E.g. np.float64 or
+            {‘a’: np.float64, ‘b’: np.int32, ‘c’: ‘Int64’}
+
+            .. versionadded:: 1.3.0

         Returns
         -------
@@ -1385,6 +1405,7 @@ def read_query(
                 index_col=index_col,
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
+                dtype=dtype,
             )
         else:
             data = result.fetchall()
@@ -1394,6 +1415,7 @@ def read_query(
                 index_col=index_col,
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
+                dtype=dtype,
             )
             return frame

@@ -1799,6 +1821,7 @@ def _query_iterator(
         index_col=None,
         coerce_float: bool = True,
         parse_dates=None,
+        dtype: Optional[DtypeArg] = None,
     ):
         """Return generator through chunked result set"""
         while True:
@@ -1815,6 +1838,7 @@ def _query_iterator(
                     index_col=index_col,
                     coerce_float=coerce_float,
                     parse_dates=parse_dates,
+                    dtype=dtype,
                 )

     def read_query(
@@ -1825,6 +1849,7 @@ def read_query(
         params=None,
         parse_dates=None,
         chunksize: Optional[int] = None,
+        dtype: Optional[DtypeArg] = None,
     ):

         args = _convert_params(sql, params)
@@ -1839,6 +1864,7 @@ def read_query(
                 index_col=index_col,
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
+                dtype=dtype,
             )
         else:
             data = self._fetchall_as_list(cursor)
@@ -1850,6 +1876,7 @@ def read_query(
                 index_col=index_col,
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
+                dtype=dtype,
             )
             return frame
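Since dtype is also threaded through both _query_iterator implementations, the cast is applied to every chunk when chunksize is given. A small sketch of that behaviour (not from the commit; the table and values are made up, and it assumes a pandas build with this change):

import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame({"a": [1.0, 2.0, 3.0]}).to_sql("t", conn, index=False)

# Each chunk passes through _wrap_result, so each one is cast on the way out.
for chunk in pd.read_sql_query(
    "SELECT a FROM t", conn, chunksize=2, dtype={"a": "Int64"}
):
    print(chunk.dtypes["a"])  # Int64 for every chunk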

pandas/tests/io/test_sql.py

+21
@@ -937,6 +937,27 @@ def test_multiindex_roundtrip(self):
         )
         tm.assert_frame_equal(df, result, check_index_type=True)

+    @pytest.mark.parametrize(
+        "dtype",
+        [
+            None,
+            int,
+            float,
+            {"A": int, "B": float},
+        ],
+    )
+    def test_dtype_argument(self, dtype):
+        # GH10285 Add dtype argument to read_sql_query
+        df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"])
+        df.to_sql("test_dtype_argument", self.conn)
+
+        expected = df.astype(dtype)
+        result = sql.read_sql_query(
+            "SELECT A, B FROM test_dtype_argument", con=self.conn, dtype=dtype
+        )
+
+        tm.assert_frame_equal(result, expected)
+
     def test_integer_col_names(self):
         df = DataFrame([[1, 2], [3, 4]], columns=[0, 1])
         sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace")
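For a quick manual check outside the test suite, a hypothetical standalone equivalent of the parametrized test above (assuming the built-in sqlite3 module and a pandas build that includes this change; pandas._testing is the internal helper the suite itself uses):

import sqlite3

import pandas as pd
import pandas._testing as tm

conn = sqlite3.connect(":memory:")
df = pd.DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"])
df.to_sql("test_dtype_argument", conn, index=False)

for dtype in [None, int, float, {"A": int, "B": float}]:
    # dtype=None leaves the frame as read; anything else is handed to .astype().
    expected = df.astype(dtype) if dtype else df
    result = pd.read_sql_query("SELECT A, B FROM test_dtype_argument", conn, dtype=dtype)
    tm.assert_frame_equal(result, expected)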
