Skip to content

Commit a297a51

Browse files
author
Chuck Cadman
committed
Make PandasSQL.execute arguments more precise.
1 parent b87dc7c commit a297a51

File tree

2 files changed

+49
-26
lines changed

2 files changed

+49
-26
lines changed

pandas/io/sql.py

+18-26
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@
6666

6767
if TYPE_CHECKING:
6868
from sqlalchemy import Table
69+
from sqlalchemy.sql.expression import (
70+
Select,
71+
TextClause,
72+
)
6973

7074

7175
# -----------------------------------------------------------------------------
@@ -80,17 +84,6 @@ def _cleanup_after_generator(generator, exit_stack: ExitStack):
8084
exit_stack.close()
8185

8286

83-
def _convert_params(sql, params):
84-
"""Convert SQL and params args to DBAPI2.0 compliant format."""
85-
args = [sql]
86-
if params is not None:
87-
if hasattr(params, "keys"): # test if params is a mapping
88-
args += [params]
89-
else:
90-
args += [list(params)]
91-
return args
92-
93-
9487
def _process_parse_dates_argument(parse_dates):
9588
"""Process parse_dates argument for read_sql functions"""
9689
# handle non-list entries for parse_dates gracefully
@@ -217,8 +210,7 @@ def execute(sql, con, params=None):
217210
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)):
218211
raise TypeError("pandas.io.sql.execute requires a connection") # GH50185
219212
with pandasSQL_builder(con, need_transaction=True) as pandas_sql:
220-
args = _convert_params(sql, params)
221-
return pandas_sql.execute(*args)
213+
return pandas_sql.execute(sql, params)
222214

223215

224216
# -----------------------------------------------------------------------------
@@ -1391,7 +1383,7 @@ def to_sql(
13911383
pass
13921384

13931385
@abstractmethod
1394-
def execute(self, *args, **kwargs):
1386+
def execute(self, sql: str | Select | TextClause, params=None):
13951387
pass
13961388

13971389
@abstractmethod
@@ -1538,9 +1530,10 @@ def __exit__(self, *args) -> None:
15381530
def run_transaction(self):
15391531
yield self.con
15401532

1541-
def execute(self, *args, **kwargs):
1533+
def execute(self, sql: str | Select | TextClause, params=None):
15421534
"""Simple passthrough to SQLAlchemy connectable"""
1543-
return self.con.execute(*args, **kwargs)
1535+
args = [] if params is None else [params]
1536+
return self.con.execute(sql, *args)
15441537

15451538
def read_table(
15461539
self,
@@ -1710,9 +1703,7 @@ def read_query(
17101703
read_sql
17111704
17121705
"""
1713-
args = _convert_params(sql, params)
1714-
1715-
result = self.execute(*args)
1706+
result = self.execute(sql, params)
17161707
columns = result.keys()
17171708

17181709
if chunksize is not None:
@@ -2170,21 +2161,24 @@ def run_transaction(self):
21702161
finally:
21712162
cur.close()
21722163

2173-
def execute(self, *args, **kwargs):
2164+
def execute(self, sql: str | Select | TextClause, params=None):
2165+
if not isinstance(sql, str):
2166+
raise TypeError("Query must be a string unless using sqlalchemy.")
2167+
args = [] if params is None else [params]
21742168
cur = self.con.cursor()
21752169
try:
2176-
cur.execute(*args, **kwargs)
2170+
cur.execute(sql, *args)
21772171
return cur
21782172
except Exception as exc:
21792173
try:
21802174
self.con.rollback()
21812175
except Exception as inner_exc: # pragma: no cover
21822176
ex = DatabaseError(
2183-
f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback"
2177+
f"Execution failed on sql: {sql}\n{exc}\nunable to rollback"
21842178
)
21852179
raise ex from inner_exc
21862180

2187-
ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}")
2181+
ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}")
21882182
raise ex from exc
21892183

21902184
@staticmethod
@@ -2237,9 +2231,7 @@ def read_query(
22372231
dtype: DtypeArg | None = None,
22382232
use_nullable_dtypes: bool = False,
22392233
) -> DataFrame | Iterator[DataFrame]:
2240-
2241-
args = _convert_params(sql, params)
2242-
cursor = self.execute(*args)
2234+
cursor = self.execute(sql, params)
22432235
columns = [col_desc[0] for col_desc in cursor.description]
22442236

22452237
if chunksize is not None:

pandas/tests/io/test_sql.py

+31
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,37 @@ def test_read_iris_query_chunksize(conn, request):
595595
assert "SepalWidth" in iris_frame.columns
596596

597597

598+
@pytest.mark.db
599+
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
600+
def test_read_iris_query_expression_with_parameter(conn, request):
601+
conn = request.getfixturevalue(conn)
602+
from sqlalchemy import (
603+
MetaData,
604+
Table,
605+
create_engine,
606+
select,
607+
)
608+
609+
metadata = MetaData()
610+
autoload_con = create_engine(conn) if isinstance(conn, str) else conn
611+
iris = Table("iris", metadata, autoload_with=autoload_con)
612+
iris_frame = read_sql_query(
613+
select(iris), conn, params={"name": "Iris-setosa", "length": 5.1}
614+
)
615+
check_iris_frame(iris_frame)
616+
617+
618+
@pytest.mark.db
619+
@pytest.mark.parametrize("conn", all_connectable_iris)
620+
def test_read_iris_query_string_with_parameter(conn, request):
621+
for db, query in SQL_STRINGS["read_parameters"].items():
622+
if db in conn:
623+
break
624+
conn = request.getfixturevalue(conn)
625+
iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1))
626+
check_iris_frame(iris_frame)
627+
628+
598629
@pytest.mark.db
599630
@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris)
600631
def test_read_iris_table(conn, request):

0 commit comments

Comments
 (0)