Skip to content

ENH: Add dtype argument to read_sql_query (GH10285) #37546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 23, 2020
Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Other enhancements

- Added :meth:`MultiIndex.dtypes` (:issue:`37062`)
- Improve error message when ``usecols`` and ``names`` do not match for :func:`read_csv` and ``engine="c"`` (:issue:`29042`)
- :func:`pandas.read_sql_query` now accepts a ``dtype`` argument to cast the columnar data from the SQL database based on user input (:issue:`10285`)

.. ---------------------------------------------------------------------------

Expand Down
15 changes: 8 additions & 7 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@
]
Timezone = Union[str, tzinfo]

# other

Dtype = Union[
"ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]]
]
DtypeObj = Union[np.dtype, "ExtensionDtype"]

# FrameOrSeriesUnion means either a DataFrame or a Series. E.g.
# `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series
# is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed
Expand All @@ -99,6 +92,14 @@
JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
Axes = Collection

# dtypes

Dtype = Union[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment on what DtypeArg is / supposed to be used

"ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]]
]
DtypeArg = Optional[Union[Dtype, Dict[Label, Dtype]]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use Optional in the spec itself

DtypeObj = Union[np.dtype, "ExtensionDtype"]

# For functions like rename that convert one label to another
Copy link
Contributor Author

@avinashpancham avinashpancham Dec 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the dtype block since Label was defined later in the file

Renamer = Union[Mapping[Label, Any], Callable[[Label], Label]]

Expand Down
47 changes: 44 additions & 3 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np

import pandas._libs.lib as lib
from pandas._typing import DtypeArg

from pandas.core.dtypes.common import is_datetime64tz_dtype, is_dict_like, is_list_like
from pandas.core.dtypes.dtypes import DatetimeTZDtype
Expand Down Expand Up @@ -124,10 +125,20 @@ def _parse_date_columns(data_frame, parse_dates):
return data_frame


def _wrap_result(data, columns, index_col=None, coerce_float=True, parse_dates=None):
def _wrap_result(
data,
columns,
index_col=None,
coerce_float=True,
parse_dates=None,
dtype: DtypeArg = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional[DtypeArg] for all of these

):
"""Wrap result set of query in a DataFrame."""
frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float)

if dtype:
frame = frame.astype(dtype)

frame = _parse_date_columns(frame, parse_dates)

if index_col is not None:
Expand Down Expand Up @@ -300,6 +311,7 @@ def read_sql_query(
params=None,
parse_dates=None,
chunksize: None = None,
dtype: DtypeArg = None,
) -> DataFrame:
...

Expand All @@ -313,6 +325,7 @@ def read_sql_query(
params=None,
parse_dates=None,
chunksize: int = 1,
dtype: DtypeArg = None,
) -> Iterator[DataFrame]:
...

Expand All @@ -325,6 +338,7 @@ def read_sql_query(
params=None,
parse_dates=None,
chunksize: Optional[int] = None,
dtype: DtypeArg = None,
) -> Union[DataFrame, Iterator[DataFrame]]:
"""
Read SQL query into a DataFrame.
Expand Down Expand Up @@ -363,6 +377,9 @@ def read_sql_query(
chunksize : int, default None
If specified, return an iterator where `chunksize` is the number of
rows to include in each chunk.
dtype : Type name or dict of columns
Data type for data or columns. E.g. np.float64 or
{'a': np.float64, 'b': np.int32, 'c': 'Int64'}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need a versionadded 1.3 here. ok to add in next PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I see i didn't commit that change. But will indeed add it to the follow on


Returns
-------
Expand All @@ -386,6 +403,7 @@ def read_sql_query(
coerce_float=coerce_float,
parse_dates=parse_dates,
chunksize=chunksize,
dtype=dtype,
)


Expand Down Expand Up @@ -1230,7 +1248,13 @@ def read_table(

@staticmethod
def _query_iterator(
result, chunksize, columns, index_col=None, coerce_float=True, parse_dates=None
result,
chunksize,
columns,
index_col=None,
coerce_float=True,
parse_dates=None,
dtype: DtypeArg = None,
):
"""Return generator through chunked result set"""
while True:
Expand All @@ -1244,6 +1268,7 @@ def _query_iterator(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)

def read_query(
Expand All @@ -1254,6 +1279,7 @@ def read_query(
parse_dates=None,
params=None,
chunksize=None,
dtype: DtypeArg = None,
):
"""
Read SQL query into a DataFrame.
Expand Down Expand Up @@ -1285,6 +1311,9 @@ def read_query(
chunksize : int, default None
If specified, return an iterator where `chunksize` is the number
of rows to include in each chunk.
dtype : Type name or dict of columns
Data type for data or columns. E.g. np.float64 or
{'a': np.float64, 'b': np.int32, 'c': 'Int64'}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

versionadded 1.3


Returns
-------
Expand All @@ -1309,6 +1338,7 @@ def read_query(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
else:
data = result.fetchall()
Expand All @@ -1318,6 +1348,7 @@ def read_query(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
return frame

Expand Down Expand Up @@ -1717,7 +1748,13 @@ def execute(self, *args, **kwargs):

@staticmethod
def _query_iterator(
cursor, chunksize, columns, index_col=None, coerce_float=True, parse_dates=None
cursor,
chunksize,
columns,
index_col=None,
coerce_float=True,
parse_dates=None,
dtype: DtypeArg = None,
):
"""Return generator through chunked result set"""
while True:
Expand All @@ -1734,6 +1771,7 @@ def _query_iterator(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)

def read_query(
Expand All @@ -1744,6 +1782,7 @@ def read_query(
params=None,
parse_dates=None,
chunksize=None,
dtype: DtypeArg = None,
):

args = _convert_params(sql, params)
Expand All @@ -1758,6 +1797,7 @@ def read_query(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
else:
data = self._fetchall_as_list(cursor)
Expand All @@ -1769,6 +1809,7 @@ def read_query(
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype=dtype,
)
return frame

Expand Down
19 changes: 19 additions & 0 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,25 @@ def test_multiindex_roundtrip(self):
)
tm.assert_frame_equal(df, result, check_index_type=True)

@pytest.mark.parametrize(
"dtype, expected",
[
(None, [float, float]),
(int, [int, int]),
(float, [float, float]),
({"SepalLength": int, "SepalWidth": float}, [int, float]),
],
)
def test_dtype_argument(self, dtype, expected):
# GH10285 Add dtype argument to read_sql_query
result = sql.read_sql_query(
"SELECT SepalLength, SepalWidth FROM iris", self.conn, dtype=dtype
)
assert result.dtypes.to_dict() == {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you constructed an expected frame and use tm.assert_frame_equal

"SepalLength": expected[0],
"SepalWidth": expected[1],
}

def test_integer_col_names(self):
df = DataFrame([[1, 2], [3, 4]], columns=[0, 1])
sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace")
Expand Down