diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index d6e0bb2ae0830..2db5f977721d8 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -37,6 +37,7 @@ The ``use_nullable_dtypes`` keyword argument has been expanded to the following
 
 * :func:`read_csv`
 * :func:`read_excel`
+* :func:`read_sql`
 
 Additionally a new global configuration, ``io.nullable_backend`` can now be used in conjunction with the parameter
 ``use_nullable_dtypes=True`` in the following functions to select the nullable dtypes implementation.
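For context before the code changes, here is a minimal end-to-end sketch of the behaviour this patch adds. The table name ``demo``, the sample data, and the use of the stdlib ``sqlite3`` driver are invented for illustration:

```python
# Sketch: with use_nullable_dtypes=True, read_sql returns masked dtypes
# (Int64 / Float64 / boolean / string) instead of float64 / object.
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame(
    {"a": pd.array([1, None, 3], dtype="Int64"), "b": ["x", None, "z"]}
).to_sql("demo", conn, index=False)

df = pd.read_sql("SELECT * FROM demo", conn, use_nullable_dtypes=True)
print(df.dtypes)  # expected: a -> Int64, b -> string (default read: float64 / object)
```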
diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py
index 07fab0080a747..9bdfd7991689b 100644
--- a/pandas/core/internals/construction.py
+++ b/pandas/core/internals/construction.py
@@ -31,9 +31,11 @@
 )
 from pandas.core.dtypes.common import (
     is_1d_only_ea_dtype,
+    is_bool_dtype,
     is_datetime_or_timedelta_dtype,
     is_dtype_equal,
     is_extension_array_dtype,
+    is_float_dtype,
     is_integer_dtype,
     is_list_like,
     is_named_tuple,
@@ -49,7 +51,13 @@
     algorithms,
     common as com,
 )
-from pandas.core.arrays import ExtensionArray
+from pandas.core.arrays import (
+    BooleanArray,
+    ExtensionArray,
+    FloatingArray,
+    IntegerArray,
+)
+from pandas.core.arrays.string_ import StringDtype
 from pandas.core.construction import (
     ensure_wrapped_if_datetimelike,
     extract_array,
@@ -900,7 +908,7 @@ def _finalize_columns_and_data(
         raise ValueError(err) from err
 
     if len(contents) and contents[0].dtype == np.object_:
-        contents = _convert_object_array(contents, dtype=dtype)
+        contents = convert_object_array(contents, dtype=dtype)
 
     return contents, columns
 
@@ -963,8 +971,11 @@ def _validate_or_indexify_columns(
     return columns
 
 
-def _convert_object_array(
-    content: list[npt.NDArray[np.object_]], dtype: DtypeObj | None
+def convert_object_array(
+    content: list[npt.NDArray[np.object_]],
+    dtype: DtypeObj | None,
+    use_nullable_dtypes: bool = False,
+    coerce_float: bool = False,
 ) -> list[ArrayLike]:
     """
     Internal function to convert object array.
@@ -973,20 +984,37 @@ def _convert_object_array(
     Parameters
     ----------
     content: List[np.ndarray]
    dtype: np.dtype or ExtensionDtype
+    use_nullable_dtypes: Controls whether nullable dtypes are returned.
+    coerce_float: Whether to attempt conversion of non-numeric objects
+        (e.g. decimal.Decimal) to float.
 
     Returns
     -------
     List[ArrayLike]
     """
     # provide soft conversion of object dtypes
+
     def convert(arr):
         if dtype != np.dtype("O"):
-            arr = lib.maybe_convert_objects(arr)
+            arr = lib.maybe_convert_objects(
+                arr,
+                try_float=coerce_float,
+                convert_to_nullable_dtype=use_nullable_dtypes,
+            )
 
             if dtype is None:
                 if arr.dtype == np.dtype("O"):
                     # i.e. maybe_convert_objects didn't convert
                     arr = maybe_infer_to_datetimelike(arr)
+                    if use_nullable_dtypes and arr.dtype == np.dtype("O"):
+                        arr = StringDtype().construct_array_type()._from_sequence(arr)
+                elif use_nullable_dtypes and isinstance(arr, np.ndarray):
+                    if is_integer_dtype(arr.dtype):
+                        arr = IntegerArray(arr, np.zeros(arr.shape, dtype=np.bool_))
+                    elif is_bool_dtype(arr.dtype):
+                        arr = BooleanArray(arr, np.zeros(arr.shape, dtype=np.bool_))
+                    elif is_float_dtype(arr.dtype):
+                        arr = FloatingArray(arr, np.isnan(arr))
+
             elif isinstance(dtype, ExtensionDtype):
                 # TODO: test(s) that get here
                 # TODO: try to de-duplicate this convert function with
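The re-wrapping at the end of ``convert`` is the core of this change: plain numpy results are wrapped into masked arrays, with an all-False mask when nothing is missing and, for floats, a mask derived from NaN. A standalone illustration using only the public ``pandas.arrays`` constructors (not the internal code path itself):

```python
# Standalone illustration of the masked-array wrapping done in convert():
# int/bool results get an all-False mask (no value is NA); float results
# derive the mask from NaN, so NaN becomes pd.NA.
import numpy as np

from pandas.arrays import FloatingArray, IntegerArray

ints = np.array([1, 2, 3], dtype=np.int64)
print(IntegerArray(ints, np.zeros(ints.shape, dtype=np.bool_)).dtype)  # Int64

floats = np.array([1.5, np.nan, 2.5])
print(FloatingArray(floats, np.isnan(floats)))  # [1.5, <NA>, 2.5], dtype: Float64
```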
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index e3510f71bd0cd..4c1dca180c6e9 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -58,6 +58,7 @@
 )
 from pandas.core.base import PandasObject
 import pandas.core.common as com
+from pandas.core.internals.construction import convert_object_array
 from pandas.core.tools.datetimes import to_datetime
 
 if TYPE_CHECKING:
@@ -139,6 +140,25 @@ def _parse_date_columns(data_frame, parse_dates):
     return data_frame
 
 
+def _convert_arrays_to_dataframe(
+    data,
+    columns,
+    coerce_float: bool = True,
+    use_nullable_dtypes: bool = False,
+) -> DataFrame:
+    content = lib.to_object_array_tuples(data)
+    arrays = convert_object_array(
+        list(content.T),
+        dtype=None,
+        coerce_float=coerce_float,
+        use_nullable_dtypes=use_nullable_dtypes,
+    )
+    if arrays:
+        return DataFrame(dict(zip(columns, arrays)))
+    else:
+        return DataFrame(columns=columns)
+
+
 def _wrap_result(
     data,
     columns,
@@ -146,9 +166,12 @@ def _wrap_result(
     coerce_float: bool = True,
     parse_dates=None,
     dtype: DtypeArg | None = None,
+    use_nullable_dtypes: bool = False,
 ):
     """Wrap result set of query in a DataFrame."""
-    frame = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float)
+    frame = _convert_arrays_to_dataframe(
+        data, columns, coerce_float, use_nullable_dtypes
+    )
 
     if dtype:
         frame = frame.astype(dtype)
@@ -156,7 +179,7 @@ def _wrap_result(
     frame = _parse_date_columns(frame, parse_dates)
 
     if index_col is not None:
-        frame.set_index(index_col, inplace=True)
+        frame = frame.set_index(index_col)
 
     return frame
 
@@ -418,6 +441,7 @@ def read_sql(
     parse_dates=...,
     columns: list[str] = ...,
     chunksize: None = ...,
+    use_nullable_dtypes: bool = ...,
 ) -> DataFrame:
     ...
 
@@ -432,6 +456,7 @@ def read_sql(
     parse_dates=...,
     columns: list[str] = ...,
     chunksize: int = ...,
+    use_nullable_dtypes: bool = ...,
 ) -> Iterator[DataFrame]:
     ...
 
@@ -445,6 +470,7 @@ def read_sql(
     parse_dates=None,
     columns: list[str] | None = None,
     chunksize: int | None = None,
+    use_nullable_dtypes: bool = False,
 ) -> DataFrame | Iterator[DataFrame]:
     """
     Read SQL query or database table into a DataFrame.
@@ -492,6 +518,12 @@ def read_sql(
     chunksize : int, default None
         If specified, return an iterator where `chunksize` is the
         number of rows to include in each chunk.
+    use_nullable_dtypes : bool = False
+        Whether to use nullable dtypes as default when reading data. If
+        set to True, nullable dtypes are used for all dtypes that have a nullable
+        implementation, even if no nulls are present.
+
+        .. versionadded:: 2.0
 
     Returns
     -------
@@ -571,6 +603,7 @@ def read_sql(
             coerce_float=coerce_float,
             parse_dates=parse_dates,
             chunksize=chunksize,
+            use_nullable_dtypes=use_nullable_dtypes,
         )
 
     try:
@@ -587,6 +620,7 @@ def read_sql(
             parse_dates=parse_dates,
             columns=columns,
             chunksize=chunksize,
+            use_nullable_dtypes=use_nullable_dtypes,
         )
     else:
         return pandas_sql.read_query(
@@ -596,6 +630,7 @@ def read_sql(
             coerce_float=coerce_float,
             parse_dates=parse_dates,
             chunksize=chunksize,
+            use_nullable_dtypes=use_nullable_dtypes,
         )
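A compact trace of the new result-set path in ``_convert_arrays_to_dataframe``, shown here with made-up rows. Note that ``lib.to_object_array_tuples`` and ``convert_object_array`` are internal APIs and may change:

```python
# Trace of the new path: DBAPI rows -> 2D object ndarray -> per-column arrays.
from pandas import DataFrame
from pandas._libs import lib
from pandas.core.internals.construction import convert_object_array

rows = [(1, "a"), (2, None), (3, "c")]      # what cursor.fetchall() might return
content = lib.to_object_array_tuples(rows)  # shape (3, 2), dtype=object
arrays = convert_object_array(
    list(content.T),                        # one object ndarray per column
    dtype=None,
    coerce_float=True,
    use_nullable_dtypes=True,
)
frame = DataFrame(dict(zip(["x", "y"], arrays)))
print(frame.dtypes)  # expected: x -> Int64, y -> string
```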
@@ -983,6 +1018,7 @@ def _query_iterator(
         columns,
         coerce_float: bool = True,
         parse_dates=None,
+        use_nullable_dtypes: bool = False,
     ):
         """Return generator through chunked result set."""
         has_read_data = False
@@ -996,11 +1032,13 @@ def _query_iterator(
                     break
 
                 has_read_data = True
-                self.frame = DataFrame.from_records(
-                    data, columns=columns, coerce_float=coerce_float
+                self.frame = _convert_arrays_to_dataframe(
+                    data, columns, coerce_float, use_nullable_dtypes
                 )
 
-                self._harmonize_columns(parse_dates=parse_dates)
+                self._harmonize_columns(
+                    parse_dates=parse_dates, use_nullable_dtypes=use_nullable_dtypes
+                )
 
                 if self.index is not None:
                     self.frame.set_index(self.index, inplace=True)
@@ -1013,6 +1051,7 @@ def read(
         parse_dates=None,
         columns=None,
         chunksize=None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
         from sqlalchemy import select
 
@@ -1034,14 +1073,17 @@ def read(
                 column_names,
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
         else:
             data = result.fetchall()
-            self.frame = DataFrame.from_records(
-                data, columns=column_names, coerce_float=coerce_float
+            self.frame = _convert_arrays_to_dataframe(
+                data, column_names, coerce_float, use_nullable_dtypes
             )
 
-            self._harmonize_columns(parse_dates=parse_dates)
+            self._harmonize_columns(
+                parse_dates=parse_dates, use_nullable_dtypes=use_nullable_dtypes
+            )
 
             if self.index is not None:
                 self.frame.set_index(self.index, inplace=True)
@@ -1124,7 +1166,9 @@ def _create_table_setup(self):
         meta = MetaData()
         return Table(self.name, meta, *columns, schema=schema)
 
-    def _harmonize_columns(self, parse_dates=None) -> None:
+    def _harmonize_columns(
+        self, parse_dates=None, use_nullable_dtypes: bool = False
+    ) -> None:
         """
         Make the DataFrame's column types align with the SQL table
         column types.
@@ -1164,11 +1208,11 @@ def _harmonize_columns(self, parse_dates=None) -> None:
                 # Convert tz-aware Datetime SQL columns to UTC
                 utc = col_type is DatetimeTZDtype
                 self.frame[col_name] = _handle_date_column(df_col, utc=utc)
-            elif col_type is float:
+            elif not use_nullable_dtypes and col_type is float:
                 # floats support NA, can always convert!
                 self.frame[col_name] = df_col.astype(col_type, copy=False)
-            elif len(df_col) == df_col.count():
+            elif not use_nullable_dtypes and len(df_col) == df_col.count():
                 # No NA values, can convert ints and bools
                 if col_type is np.dtype("int64") or col_type is bool:
                     self.frame[col_name] = df_col.astype(col_type, copy=False)
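Both the query path (``_wrap_result``) and the table path (``_harmonize_columns``) now preserve the converted dtypes; the guards above matter because the old numpy casts would undo the nullable conversion. The user-visible effect, sketched against an invented in-memory SQLite table:

```python
# By default an INTEGER column containing NULL comes back as float64 (NaN);
# the nullable path keeps it as Int64 with pd.NA instead.
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(1,), (None,), (3,)])

default = pd.read_sql("SELECT * FROM t", conn)
nullable = pd.read_sql("SELECT * FROM t", conn, use_nullable_dtypes=True)
print(default["a"].dtype, nullable["a"].dtype)  # expected: float64 Int64
```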
@@ -1290,6 +1334,7 @@ def read_table(
         columns=None,
         schema: str | None = None,
         chunksize: int | None = None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
         raise NotImplementedError
 
@@ -1303,6 +1348,7 @@ def read_query(
         params=None,
         chunksize: int | None = None,
         dtype: DtypeArg | None = None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
         pass
 
@@ -1466,6 +1512,7 @@ def read_table(
         columns=None,
         schema: str | None = None,
         chunksize: int | None = None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
         """
         Read SQL database table into a DataFrame.
@@ -1498,6 +1545,12 @@ def read_table(
         chunksize : int, default None
             If specified, return an iterator where `chunksize` is the number
             of rows to include in each chunk.
+        use_nullable_dtypes : bool = False
+            Whether to use nullable dtypes as default when reading data. If
+            set to True, nullable dtypes are used for all dtypes that have a nullable
+            implementation, even if no nulls are present.
+
+            .. versionadded:: 2.0
 
         Returns
         -------
@@ -1516,6 +1569,7 @@ def read_table(
             parse_dates=parse_dates,
             columns=columns,
             chunksize=chunksize,
+            use_nullable_dtypes=use_nullable_dtypes,
         )
 
     @staticmethod
@@ -1527,6 +1581,7 @@ def _query_iterator(
         coerce_float: bool = True,
         parse_dates=None,
         dtype: DtypeArg | None = None,
+        use_nullable_dtypes: bool = False,
     ):
         """Return generator through chunked result set"""
         has_read_data = False
@@ -1540,6 +1595,7 @@ def _query_iterator(
                         index_col=index_col,
                         coerce_float=coerce_float,
                         parse_dates=parse_dates,
+                        use_nullable_dtypes=use_nullable_dtypes,
                     )
                 break
 
@@ -1551,6 +1607,7 @@ def _query_iterator(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
 
     def read_query(
@@ -1562,6 +1619,7 @@ def read_query(
         params=None,
         chunksize: int | None = None,
         dtype: DtypeArg | None = None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
         """
         Read SQL query into a DataFrame.
@@ -1623,6 +1681,7 @@ def read_query(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
         else:
             data = result.fetchall()
@@ -1633,6 +1692,7 @@ def read_query(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
             return frame
 
@@ -2089,6 +2149,7 @@ def _query_iterator(
         coerce_float: bool = True,
         parse_dates=None,
         dtype: DtypeArg | None = None,
+        use_nullable_dtypes: bool = False,
     ):
         """Return generator through chunked result set"""
         has_read_data = False
@@ -2112,6 +2173,7 @@ def _query_iterator(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
 
     def read_query(
@@ -2123,6 +2185,7 @@ def read_query(
         params=None,
         chunksize: int | None = None,
         dtype: DtypeArg | None = None,
+        use_nullable_dtypes: bool = False,
     ) -> DataFrame | Iterator[DataFrame]:
 
         args = _convert_params(sql, params)
@@ -2138,6 +2201,7 @@ def read_query(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
             )
         else:
             data = self._fetchall_as_list(cursor)
@@ -2150,6 +2214,7 @@ def read_query(
                 coerce_float=coerce_float,
                 parse_dates=parse_dates,
                 dtype=dtype,
+                use_nullable_dtypes=use_nullable_dtypes,
            )
 
             return frame
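Every ``_query_iterator`` above threads the flag through, so chunked reads should produce the same dtypes as a single read. A small sketch, again with invented data:

```python
# Chunked reads run through the same conversion, so each chunk keeps Int64.
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame({"a": pd.array([1, None, 3, 4], dtype="Int64")}).to_sql(
    "t", conn, index=False
)

chunks = pd.read_sql(
    "SELECT * FROM t", conn, use_nullable_dtypes=True, chunksize=2
)
for chunk in chunks:
    print(chunk["a"].dtype)  # expected: Int64 for every chunk
```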
diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py
index 394fceb69b788..db37b1785af5c 100644
--- a/pandas/tests/io/test_sql.py
+++ b/pandas/tests/io/test_sql.py
@@ -53,6 +53,10 @@
     to_timedelta,
 )
 import pandas._testing as tm
+from pandas.core.arrays import (
+    ArrowStringArray,
+    StringArray,
+)
 
 from pandas.io import sql
 from pandas.io.sql import (
@@ -2266,6 +2270,94 @@ def test_get_engine_auto_error_message(self):
             pass
             # TODO(GH#36893) fill this in when we add more engines
 
+    @pytest.mark.parametrize("storage", ["pyarrow", "python"])
+    def test_read_sql_nullable_dtypes(self, storage):
+        # GH#50048
+        table = "test"
+        df = self.nullable_data()
+        df.to_sql(table, self.conn, index=False, if_exists="replace")
+
+        with pd.option_context("mode.string_storage", storage):
+            result = pd.read_sql(
+                f"Select * from {table}", self.conn, use_nullable_dtypes=True
+            )
+        expected = self.nullable_expected(storage)
+        tm.assert_frame_equal(result, expected)
+
+        with pd.option_context("mode.string_storage", storage):
+            iterator = pd.read_sql(
+                f"Select * from {table}",
+                self.conn,
+                use_nullable_dtypes=True,
+                chunksize=3,
+            )
+            expected = self.nullable_expected(storage)
+            for result in iterator:
+                tm.assert_frame_equal(result, expected)
+
+    @pytest.mark.parametrize("storage", ["pyarrow", "python"])
+    def test_read_sql_nullable_dtypes_table(self, storage):
+        # GH#50048
+        table = "test"
+        df = self.nullable_data()
+        df.to_sql(table, self.conn, index=False, if_exists="replace")
+
+        with pd.option_context("mode.string_storage", storage):
+            result = pd.read_sql(table, self.conn, use_nullable_dtypes=True)
+        expected = self.nullable_expected(storage)
+        tm.assert_frame_equal(result, expected)
+
+        with pd.option_context("mode.string_storage", storage):
+            iterator = pd.read_sql(
+                table,
+                self.conn,
+                use_nullable_dtypes=True,
+                chunksize=3,
+            )
+            expected = self.nullable_expected(storage)
+            for result in iterator:
+                tm.assert_frame_equal(result, expected)
+
+    def nullable_data(self) -> DataFrame:
+        return DataFrame(
+            {
+                "a": Series([1, np.nan, 3], dtype="Int64"),
+                "b": Series([1, 2, 3], dtype="Int64"),
+                "c": Series([1.5, np.nan, 2.5], dtype="Float64"),
+                "d": Series([1.5, 2.0, 2.5], dtype="Float64"),
+                "e": [True, False, None],
+                "f": [True, False, True],
+                "g": ["a", "b", "c"],
+                "h": ["a", "b", None],
+            }
+        )
+
+    def nullable_expected(self, storage) -> DataFrame:
+        string_array: StringArray | ArrowStringArray
+        string_array_na: StringArray | ArrowStringArray
+        if storage == "python":
+            string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_))
+            string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_))
+        else:
+            pa = pytest.importorskip("pyarrow")
+            string_array = ArrowStringArray(pa.array(["a", "b", "c"]))
+            string_array_na = ArrowStringArray(pa.array(["a", "b", None]))
+
+        return DataFrame(
+            {
+                "a": Series([1, np.nan, 3], dtype="Int64"),
+                "b": Series([1, 2, 3], dtype="Int64"),
+                "c": Series([1.5, np.nan, 2.5], dtype="Float64"),
+                "d": Series([1.5, 2.0, 2.5], dtype="Float64"),
+                "e": Series([True, False, pd.NA], dtype="boolean"),
+                "f": Series([True, False, True], dtype="boolean"),
+                "g": string_array,
+                "h": string_array_na,
+            }
+        )
+
 
 class TestSQLiteAlchemy(_TestSQLAlchemy):
     """
@@ -2349,6 +2441,14 @@ class Test(BaseModel):
 
         assert list(df.columns) == ["id", "string_column"]
 
+    def nullable_expected(self, storage) -> DataFrame:
+        return super().nullable_expected(storage).astype({"e": "Int64", "f": "Int64"})
+
+    @pytest.mark.parametrize("storage", ["pyarrow", "python"])
+    def test_read_sql_nullable_dtypes_table(self, storage):
+        # GH#50048 Not supported for sqlite
+        pass
+
 
 @pytest.mark.db
 class TestMySQLAlchemy(_TestSQLAlchemy):
@@ -2376,6 +2476,9 @@ def setup_driver(cls):
     def test_default_type_conversion(self):
         pass
 
+    def nullable_expected(self, storage) -> DataFrame:
+        return super().nullable_expected(storage).astype({"e": "Int64", "f": "Int64"})
+
 
 @pytest.mark.db
 class TestPostgreSQLAlchemy(_TestSQLAlchemy):
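The tests parametrize over both string storages; outside the test suite the same toggle applies through the ``mode.string_storage`` option (the ``pyarrow`` storage requires pyarrow to be installed). A sketch with an invented table:

```python
# Object columns honour mode.string_storage when use_nullable_dtypes=True.
import sqlite3

import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame({"g": ["a", "b", None]}).to_sql("t", conn, index=False)

with pd.option_context("mode.string_storage", "pyarrow"):  # needs pyarrow
    df = pd.read_sql("SELECT * FROM t", conn, use_nullable_dtypes=True)
print(df["g"].dtype)  # expected: string[pyarrow]
```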