Skip to content

Commit 69a8150

Browse files
authored
ENH: Add dtype to read sql to be consistent with read_sql_query (pandas-dev#50797)
* ENH: Add dtype to read sql to be consistent with read_sql_query * Add gh ref * Fix docstring * Add test
1 parent 6ecb52e commit 69a8150

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ Other enhancements
165165
- Improved error message in :func:`to_datetime` for non-ISO8601 formats, informing users about the position of the first error (:issue:`50361`)
166166
- Improved error message when trying to align :class:`DataFrame` objects (for example, in :func:`DataFrame.compare`) to clarify that "identically labelled" refers to both index and columns (:issue:`50083`)
167167
- Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`)
168+
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
168169
-
169170

170171
.. ---------------------------------------------------------------------------

pandas/io/sql.py

+11
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def read_sql(
470470
columns: list[str] = ...,
471471
chunksize: None = ...,
472472
use_nullable_dtypes: bool = ...,
473+
dtype: DtypeArg | None = None,
473474
) -> DataFrame:
474475
...
475476

@@ -485,6 +486,7 @@ def read_sql(
485486
columns: list[str] = ...,
486487
chunksize: int = ...,
487488
use_nullable_dtypes: bool = ...,
489+
dtype: DtypeArg | None = None,
488490
) -> Iterator[DataFrame]:
489491
...
490492

@@ -499,6 +501,7 @@ def read_sql(
499501
columns: list[str] | None = None,
500502
chunksize: int | None = None,
501503
use_nullable_dtypes: bool = False,
504+
dtype: DtypeArg | None = None,
502505
) -> DataFrame | Iterator[DataFrame]:
503506
"""
504507
Read SQL query or database table into a DataFrame.
@@ -552,6 +555,12 @@ def read_sql(
552555
implementation, even if no nulls are present.
553556
554557
.. versionadded:: 2.0
558+
dtype : Type name or dict of columns
559+
Data type for data or columns. E.g. np.float64 or
560+
{'a': np.float64, 'b': np.int32, 'c': 'Int64'}.
561+
The argument is ignored if a table is passed instead of a query.
562+
563+
.. versionadded:: 2.0.0
555564
556565
Returns
557566
-------
@@ -632,6 +641,7 @@ def read_sql(
632641
parse_dates=parse_dates,
633642
chunksize=chunksize,
634643
use_nullable_dtypes=use_nullable_dtypes,
644+
dtype=dtype,
635645
)
636646

637647
try:
@@ -659,6 +669,7 @@ def read_sql(
659669
parse_dates=parse_dates,
660670
chunksize=chunksize,
661671
use_nullable_dtypes=use_nullable_dtypes,
672+
dtype=dtype,
662673
)
663674

664675

pandas/tests/io/test_sql.py

+24
Original file line numberDiff line numberDiff line change
@@ -2394,6 +2394,30 @@ def test_chunksize_empty_dtypes(self):
23942394
):
23952395
tm.assert_frame_equal(result, expected)
23962396

2397+
@pytest.mark.parametrize("use_nullable_dtypes", [True, False])
2398+
@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"])
2399+
def test_read_sql_dtype(self, func, use_nullable_dtypes):
2400+
# GH#50797
2401+
table = "test"
2402+
df = DataFrame({"a": [1, 2, 3], "b": 5})
2403+
df.to_sql(table, self.conn, index=False, if_exists="replace")
2404+
2405+
result = getattr(pd, func)(
2406+
f"Select * from {table}",
2407+
self.conn,
2408+
dtype={"a": np.float64},
2409+
use_nullable_dtypes=use_nullable_dtypes,
2410+
)
2411+
expected = DataFrame(
2412+
{
2413+
"a": Series([1, 2, 3], dtype=np.float64),
2414+
"b": Series(
2415+
[5, 5, 5], dtype="int64" if not use_nullable_dtypes else "Int64"
2416+
),
2417+
}
2418+
)
2419+
tm.assert_frame_equal(result, expected)
2420+
23972421

23982422
class TestSQLiteAlchemy(_TestSQLAlchemy):
23992423
"""

0 commit comments

Comments
 (0)