Skip to content

Commit e1bc42a

Browse files
Add typing for io/sql.py
1 parent 0805043 commit e1bc42a

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

pandas/io/sql.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ def read_sql_query(
383383
Data type for data or columns. E.g. np.float64 or
384384
{‘a’: np.float64, ‘b’: np.int32, ‘c’: ‘Int64’}
385385
386+
.. versionadded:: 1.3.0
387+
386388
Returns
387389
-------
388390
DataFrame or Iterator[DataFrame]
@@ -609,7 +611,7 @@ def to_sql(
609611
index: bool = True,
610612
index_label=None,
611613
chunksize: Optional[int] = None,
612-
dtype=None,
614+
dtype: Optional[DtypeArg] = None,
613615
method: Optional[str] = None,
614616
) -> None:
615617
"""
@@ -768,7 +770,7 @@ def __init__(
768770
index_label=None,
769771
schema=None,
770772
keys=None,
771-
dtype=None,
773+
dtype: Optional[DtypeArg] = None,
772774
):
773775
self.name = name
774776
self.pd_sql = pandas_sql_engine
@@ -1108,9 +1110,9 @@ def _harmonize_columns(self, parse_dates=None):
11081110

11091111
def _sqlalchemy_type(self, col):
11101112

1111-
dtype = self.dtype or {}
1112-
if col.name in dtype:
1113-
return self.dtype[col.name]
1113+
dtype: DtypeArg = self.dtype or {}
1114+
if isinstance(dtype, dict) and col.name in dtype:
1115+
return dtype[col.name]
11141116

11151117
# Infer type of column, while ignoring missing values.
11161118
# Needed for inserting typed data containing NULLs, GH 8778.
@@ -1203,7 +1205,18 @@ def read_sql(self, *args, **kwargs):
12031205
"connectable or sqlite connection"
12041206
)
12051207

1206-
def to_sql(self, *args, **kwargs):
1208+
def to_sql(
1209+
self,
1210+
frame,
1211+
name,
1212+
if_exists="fail",
1213+
index=True,
1214+
index_label=None,
1215+
schema=None,
1216+
chunksize=None,
1217+
dtype: Optional[DtypeArg] = None,
1218+
method=None,
1219+
):
12071220
raise ValueError(
12081221
"PandasSQL must be created with an SQLAlchemy "
12091222
"connectable or sqlite connection"
@@ -1430,7 +1443,7 @@ def to_sql(
14301443
index_label=None,
14311444
schema=None,
14321445
chunksize=None,
1433-
dtype=None,
1446+
dtype: Optional[DtypeArg] = None,
14341447
method=None,
14351448
):
14361449
"""
@@ -1477,7 +1490,7 @@ def to_sql(
14771490
if dtype and not is_dict_like(dtype):
14781491
dtype = {col_name: dtype for col_name in frame}
14791492

1480-
if dtype is not None:
1493+
if dtype is not None and isinstance(dtype, dict):
14811494
from sqlalchemy.types import TypeEngine, to_instance
14821495

14831496
for col, my_type in dtype.items():
@@ -1563,7 +1576,7 @@ def _create_sql_schema(
15631576
frame: DataFrame,
15641577
table_name: str,
15651578
keys: Optional[List[str]] = None,
1566-
dtype: Optional[dict] = None,
1579+
dtype: Optional[DtypeArg] = None,
15671580
schema: Optional[str] = None,
15681581
):
15691582
table = SQLTable(
@@ -1734,8 +1747,8 @@ def _create_table_setup(self):
17341747
return create_stmts
17351748

17361749
def _sql_type_name(self, col):
1737-
dtype = self.dtype or {}
1738-
if col.name in dtype:
1750+
dtype: DtypeArg = self.dtype or {}
1751+
if isinstance(dtype, dict) and col.name in dtype:
17391752
return dtype[col.name]
17401753

17411754
# Infer type of column, while ignoring missing values.
@@ -1895,7 +1908,7 @@ def to_sql(
18951908
index_label=None,
18961909
schema=None,
18971910
chunksize=None,
1898-
dtype=None,
1911+
dtype: Optional[DtypeArg] = None,
18991912
method=None,
19001913
):
19011914
"""
@@ -1941,7 +1954,7 @@ def to_sql(
19411954
if dtype and not is_dict_like(dtype):
19421955
dtype = {col_name: dtype for col_name in frame}
19431956

1944-
if dtype is not None:
1957+
if dtype is not None and isinstance(dtype, dict):
19451958
for col, my_type in dtype.items():
19461959
if not isinstance(my_type, str):
19471960
raise ValueError(f"{col} ({my_type}) not a string")
@@ -1980,7 +1993,7 @@ def _create_sql_schema(
19801993
frame,
19811994
table_name: str,
19821995
keys=None,
1983-
dtype=None,
1996+
dtype: Optional[DtypeArg] = None,
19841997
schema: Optional[str] = None,
19851998
):
19861999
table = SQLiteTable(
@@ -1996,7 +2009,12 @@ def _create_sql_schema(
19962009

19972010

19982011
def get_schema(
1999-
frame, name: str, keys=None, con=None, dtype=None, schema: Optional[str] = None
2012+
frame,
2013+
name: str,
2014+
keys=None,
2015+
con=None,
2016+
dtype: Optional[DtypeArg] = None,
2017+
schema: Optional[str] = None,
20002018
):
20012019
"""
20022020
Get the SQL db table schema for the given frame.

0 commit comments

Comments
 (0)