diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 65bfd8289fe3d..5ddf68cad8baf 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -43,6 +43,7 @@ Other enhancements - Added ``end`` and ``end_day`` options for ``origin`` in :meth:`DataFrame.resample` (:issue:`37804`) - Improve error message when ``usecols`` and ``names`` do not match for :func:`read_csv` and ``engine="c"`` (:issue:`29042`) - Improved consistency of error message when passing an invalid ``win_type`` argument in :class:`Window` (:issue:`15969`) +- :func:`pandas.read_sql_query` now accepts a ``dtype`` argument to cast the columnar data from the SQL database based on user input (:issue:`10285`) .. --------------------------------------------------------------------------- diff --git a/pandas/_typing.py b/pandas/_typing.py index c79942c48509e..64452bf337361 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -71,13 +71,6 @@ ] Timezone = Union[str, tzinfo] -# other - -Dtype = Union[ - "ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]] -] -DtypeObj = Union[np.dtype, "ExtensionDtype"] - # FrameOrSeriesUnion means either a DataFrame or a Series. E.g. 
# `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series # is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed @@ -100,6 +93,14 @@ JSONSerializable = Optional[Union[PythonScalar, List, Dict]] Axes = Collection +# dtypes +Dtype = Union[ + "ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]] +] +# DtypeArg specifies all allowable dtypes in a function's ``dtype`` argument +DtypeArg = Union[Dtype, Dict[Label, Dtype]] +DtypeObj = Union[np.dtype, "ExtensionDtype"] + # For functions like rename that convert one label to another Renamer = Union[Mapping[Label, Any], Callable[[Label], Label]] diff --git a/pandas/io/sql.py b/pandas/io/sql.py index a6708896f4f2e..0ad9140f2a757 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -13,6 +13,7 @@ import numpy as np import pandas._libs.lib as lib +from pandas._typing import DtypeArg from pandas.core.dtypes.common import is_datetime64tz_dtype, is_dict_like, is_list_like from pandas.core.dtypes.dtypes import DatetimeTZDtype @@ -132,10 +133,14 @@ def _wrap_result( index_col=None, coerce_float: bool = True, parse_dates=None, + dtype: Optional[DtypeArg] = None, ): """Wrap result set of query in a DataFrame.""" frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float) + if dtype: + frame = frame.astype(dtype) + frame = _parse_date_columns(frame, parse_dates) if index_col is not None: @@ -308,6 +313,7 @@ def read_sql_query( params=None, parse_dates=None, chunksize: None = None, + dtype: Optional[DtypeArg] = None, ) -> DataFrame: ... @@ -321,6 +327,7 @@ def read_sql_query( params=None, parse_dates=None, chunksize: int = 1, + dtype: Optional[DtypeArg] = None, ) -> Iterator[DataFrame]: ... @@ -333,6 +340,7 @@ def read_sql_query( params=None, parse_dates=None, chunksize: Optional[int] = None, + dtype: Optional[DtypeArg] = None, ) -> Union[DataFrame, Iterator[DataFrame]]: """ Read SQL query into a DataFrame. 
@@ -371,6 +379,9 @@ def read_sql_query( chunksize : int, default None If specified, return an iterator where `chunksize` is the number of rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} Returns ------- @@ -394,6 +405,7 @@ def read_sql_query( coerce_float=coerce_float, parse_dates=parse_dates, chunksize=chunksize, + dtype=dtype, ) @@ -1307,6 +1319,7 @@ def _query_iterator( index_col=None, coerce_float=True, parse_dates=None, + dtype: Optional[DtypeArg] = None, ): """Return generator through chunked result set""" while True: @@ -1320,6 +1333,7 @@ def _query_iterator( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) def read_query( @@ -1330,6 +1344,7 @@ def read_query( parse_dates=None, params=None, chunksize: Optional[int] = None, + dtype: Optional[DtypeArg] = None, ): """ Read SQL query into a DataFrame. @@ -1361,6 +1376,11 @@ def read_query( chunksize : int, default None If specified, return an iterator where `chunksize` is the number of rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} + + .. 
versionadded:: 1.3.0 Returns ------- @@ -1385,6 +1405,7 @@ def read_query( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) else: data = result.fetchall() @@ -1394,6 +1415,7 @@ def read_query( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) return frame @@ -1799,6 +1821,7 @@ def _query_iterator( index_col=None, coerce_float: bool = True, parse_dates=None, + dtype: Optional[DtypeArg] = None, ): """Return generator through chunked result set""" while True: @@ -1815,6 +1838,7 @@ def _query_iterator( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) def read_query( @@ -1825,6 +1849,7 @@ def read_query( params=None, parse_dates=None, chunksize: Optional[int] = None, + dtype: Optional[DtypeArg] = None, ): args = _convert_params(sql, params) @@ -1839,6 +1864,7 @@ def read_query( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) else: data = self._fetchall_as_list(cursor) @@ -1850,6 +1876,7 @@ def read_query( index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, + dtype=dtype, ) return frame diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 9df8ff9e8ee06..fdd42ec0cc5ab 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -937,6 +937,27 @@ def test_multiindex_roundtrip(self): ) tm.assert_frame_equal(df, result, check_index_type=True) + @pytest.mark.parametrize( + "dtype", + [ + None, + int, + float, + {"A": int, "B": float}, + ], + ) + def test_dtype_argument(self, dtype): + # GH10285 Add dtype argument to read_sql_query + df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"]) + df.to_sql("test_dtype_argument", self.conn) + + expected = df.astype(dtype) + result = sql.read_sql_query( + "SELECT A, B FROM test_dtype_argument", con=self.conn, dtype=dtype + ) + + tm.assert_frame_equal(result, expected) + def test_integer_col_names(self): df = 
DataFrame([[1, 2], [3, 4]], columns=[0, 1]) sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace")