-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: Add dtype argument to read_sql_query (GH10285) #37546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
bcbe5ea
9c4f034
620c0ab
bcef60e
5c88e5c
24308c4
d6cc4b7
5de64f2
e9be344
d7d4439
dbf1f5f
a4e7cdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,13 +71,6 @@ | |
] | ||
Timezone = Union[str, tzinfo] | ||
|
||
# other | ||
|
||
Dtype = Union[ | ||
"ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]] | ||
] | ||
DtypeObj = Union[np.dtype, "ExtensionDtype"] | ||
|
||
# FrameOrSeriesUnion means either a DataFrame or a Series. E.g. | ||
# `def func(a: FrameOrSeriesUnion) -> FrameOrSeriesUnion: ...` means that if a Series | ||
# is passed in, either a Series or DataFrame is returned, and if a DataFrame is passed | ||
|
@@ -99,6 +92,14 @@ | |
JSONSerializable = Optional[Union[PythonScalar, List, Dict]] | ||
Axes = Collection | ||
|
||
# dtypes | ||
|
||
Dtype = Union[ | ||
"ExtensionDtype", str, np.dtype, Type[Union[str, float, int, complex, bool, object]] | ||
] | ||
DtypeArg = Optional[Union[Dtype, Dict[Label, Dtype]]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't use Optional in the spec itself |
||
DtypeObj = Union[np.dtype, "ExtensionDtype"] | ||
|
||
# For functions like rename that convert one label to another | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved the dtype block since |
||
Renamer = Union[Mapping[Label, Any], Callable[[Label], Label]] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import numpy as np | ||
|
||
import pandas._libs.lib as lib | ||
from pandas._typing import DtypeArg | ||
|
||
from pandas.core.dtypes.common import is_datetime64tz_dtype, is_dict_like, is_list_like | ||
from pandas.core.dtypes.dtypes import DatetimeTZDtype | ||
|
@@ -124,10 +125,20 @@ def _parse_date_columns(data_frame, parse_dates): | |
return data_frame | ||
|
||
|
||
def _wrap_result(data, columns, index_col=None, coerce_float=True, parse_dates=None): | ||
def _wrap_result( | ||
data, | ||
columns, | ||
index_col=None, | ||
coerce_float=True, | ||
parse_dates=None, | ||
dtype: DtypeArg = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional[DtypeArg] for all of these |
||
): | ||
"""Wrap result set of query in a DataFrame.""" | ||
frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float) | ||
|
||
if dtype: | ||
frame = frame.astype(dtype) | ||
|
||
frame = _parse_date_columns(frame, parse_dates) | ||
|
||
if index_col is not None: | ||
|
@@ -300,6 +311,7 @@ def read_sql_query( | |
params=None, | ||
parse_dates=None, | ||
chunksize: None = None, | ||
dtype: DtypeArg = None, | ||
) -> DataFrame: | ||
... | ||
|
||
|
@@ -313,6 +325,7 @@ def read_sql_query( | |
params=None, | ||
parse_dates=None, | ||
chunksize: int = 1, | ||
dtype: DtypeArg = None, | ||
) -> Iterator[DataFrame]: | ||
... | ||
|
||
|
@@ -325,6 +338,7 @@ def read_sql_query( | |
params=None, | ||
parse_dates=None, | ||
chunksize: Optional[int] = None, | ||
dtype: DtypeArg = None, | ||
) -> Union[DataFrame, Iterator[DataFrame]]: | ||
""" | ||
Read SQL query into a DataFrame. | ||
|
@@ -363,6 +377,9 @@ def read_sql_query( | |
chunksize : int, default None | ||
If specified, return an iterator where `chunksize` is the number of | ||
rows to include in each chunk. | ||
dtype : Type name or dict of columns | ||
Data type for data or columns. E.g. np.float64 or | ||
{'a': np.float64, 'b': np.int32, 'c': 'Int64'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need a versionadded 1.3 here. ok to add in next PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I see i didn't commit that change. But will indeed add it to the follow on |
||
|
||
Returns | ||
------- | ||
|
@@ -386,6 +403,7 @@ def read_sql_query( | |
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
chunksize=chunksize, | ||
dtype=dtype, | ||
) | ||
|
||
|
||
|
@@ -1230,7 +1248,13 @@ def read_table( | |
|
||
@staticmethod | ||
def _query_iterator( | ||
result, chunksize, columns, index_col=None, coerce_float=True, parse_dates=None | ||
result, | ||
chunksize, | ||
columns, | ||
index_col=None, | ||
coerce_float=True, | ||
parse_dates=None, | ||
dtype: DtypeArg = None, | ||
): | ||
"""Return generator through chunked result set""" | ||
while True: | ||
|
@@ -1244,6 +1268,7 @@ def _query_iterator( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
|
||
def read_query( | ||
|
@@ -1254,6 +1279,7 @@ def read_query( | |
parse_dates=None, | ||
params=None, | ||
chunksize=None, | ||
dtype: DtypeArg = None, | ||
): | ||
""" | ||
Read SQL query into a DataFrame. | ||
|
@@ -1285,6 +1311,9 @@ def read_query( | |
chunksize : int, default None | ||
If specified, return an iterator where `chunksize` is the number | ||
of rows to include in each chunk. | ||
dtype : Type name or dict of columns | ||
Data type for data or columns. E.g. np.float64 or | ||
{'a': np.float64, 'b': np.int32, 'c': 'Int64'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. versionadded 1.3 |
||
|
||
Returns | ||
------- | ||
|
@@ -1309,6 +1338,7 @@ def read_query( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
else: | ||
data = result.fetchall() | ||
|
@@ -1318,6 +1348,7 @@ def read_query( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
return frame | ||
|
||
|
@@ -1717,7 +1748,13 @@ def execute(self, *args, **kwargs): | |
|
||
@staticmethod | ||
def _query_iterator( | ||
cursor, chunksize, columns, index_col=None, coerce_float=True, parse_dates=None | ||
cursor, | ||
chunksize, | ||
columns, | ||
index_col=None, | ||
coerce_float=True, | ||
parse_dates=None, | ||
dtype: DtypeArg = None, | ||
): | ||
"""Return generator through chunked result set""" | ||
while True: | ||
|
@@ -1734,6 +1771,7 @@ def _query_iterator( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
|
||
def read_query( | ||
|
@@ -1744,6 +1782,7 @@ def read_query( | |
params=None, | ||
parse_dates=None, | ||
chunksize=None, | ||
dtype: DtypeArg = None, | ||
): | ||
|
||
args = _convert_params(sql, params) | ||
|
@@ -1758,6 +1797,7 @@ def read_query( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
else: | ||
data = self._fetchall_as_list(cursor) | ||
|
@@ -1769,6 +1809,7 @@ def read_query( | |
index_col=index_col, | ||
coerce_float=coerce_float, | ||
parse_dates=parse_dates, | ||
dtype=dtype, | ||
) | ||
return frame | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -937,6 +937,25 @@ def test_multiindex_roundtrip(self): | |
) | ||
tm.assert_frame_equal(df, result, check_index_type=True) | ||
|
||
@pytest.mark.parametrize( | ||
"dtype, expected", | ||
[ | ||
(None, [float, float]), | ||
(int, [int, int]), | ||
(float, [float, float]), | ||
({"SepalLength": int, "SepalWidth": float}, [int, float]), | ||
], | ||
) | ||
def test_dtype_argument(self, dtype, expected): | ||
# GH10285 Add dtype argument to read_sql_query | ||
result = sql.read_sql_query( | ||
"SELECT SepalLength, SepalWidth FROM iris", self.conn, dtype=dtype | ||
) | ||
assert result.dtypes.to_dict() == { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you constructed an expected frame and use tm.assert_frame_equal |
||
"SepalLength": expected[0], | ||
"SepalWidth": expected[1], | ||
} | ||
|
||
def test_integer_col_names(self): | ||
df = DataFrame([[1, 2], [3, 4]], columns=[0, 1]) | ||
sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists="replace") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment on what DtypeArg is / supposed to be used