From d9e99ec5a078d41625f8f5c4716555f78051bc3d Mon Sep 17 00:00:00 2001 From: Fangchen Li Date: Tue, 17 Aug 2021 12:02:31 -0500 Subject: [PATCH] TST: refactor drop_table in sql test --- pandas/tests/io/test_sql.py | 46 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 6136bd0e1e057..d6ace9a997951 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -249,6 +249,24 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str): conn.execute(stmt) +def count_rows(conn, table_name: str): + stmt = f"SELECT count(*) AS count_1 FROM {table_name}" + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + result = cur.execute(stmt) + else: + from sqlalchemy import text + from sqlalchemy.engine import Engine + + stmt = text(stmt) + if isinstance(conn, Engine): + with conn.connect() as conn: + result = conn.execute(stmt) + else: + result = conn.execute(stmt) + return result.fetchone()[0] + + @pytest.fixture def iris_path(datapath): iris_path = datapath("io", "data", "csv", "iris.csv") @@ -415,12 +433,6 @@ class PandasSQLTest: """ - def _get_exec(self): - if hasattr(self.conn, "execute"): - return self.conn - else: - return self.conn.cursor() - @pytest.fixture def load_iris_data(self, iris_path): if not hasattr(self, "conn"): @@ -451,14 +463,6 @@ def _check_iris_loaded_frame(self, iris_frame): assert issubclass(pytype, np.floating) tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]) - def _count_rows(self, table_name): - result = ( - self._get_exec() - .execute(f"SELECT count(*) AS count_1 FROM {table_name}") - .fetchone() - ) - return result[0] - def _read_sql_iris(self): iris_frame = self.pandasSQL.read_query("SELECT * FROM iris") self._check_iris_loaded_frame(iris_frame) @@ -487,7 +491,7 @@ def _to_sql(self, test_frame1, method=None): assert self.pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) - num_rows = self._count_rows("test_frame1") + num_rows = count_rows(self.conn, "test_frame1") assert num_rows == num_entries # Nuke table @@ -518,7 +522,7 @@ def _to_sql_replace(self, test_frame1): assert self.pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) - num_rows = self._count_rows("test_frame1") + num_rows = count_rows(self.conn, "test_frame1") assert num_rows == num_entries self.drop_table("test_frame1") @@ -534,7 +538,7 @@ def _to_sql_append(self, test_frame1): assert self.pandasSQL.has_table("test_frame1") num_entries = 2 * len(test_frame1) - num_rows = self._count_rows("test_frame1") + num_rows = count_rows(self.conn, "test_frame1") assert num_rows == num_entries self.drop_table("test_frame1") @@ -554,7 +558,7 @@ def sample(pd_table, conn, keys, data_iter): assert check == [1] num_entries = len(test_frame1) - num_rows = self._count_rows("test_frame1") + num_rows = count_rows(self.conn, "test_frame1") assert num_rows == num_entries # Nuke table self.drop_table("test_frame1") @@ -570,7 +574,7 @@ def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): assert self.pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) - num_rows = self._count_rows("test_frame1") + num_rows = count_rows(self.conn, "test_frame1") assert num_rows == num_entries # Nuke table @@ -695,7 +699,7 @@ def test_to_sql_replace(self, test_frame1): assert sql.has_table("test_frame3", self.conn) num_entries = len(test_frame1) - num_rows = self._count_rows("test_frame3") + num_rows = count_rows(self.conn, "test_frame3") assert num_rows == num_entries @@ -707,7 +711,7 @@ def test_to_sql_append(self, test_frame1): assert sql.has_table("test_frame4", self.conn) num_entries = 2 * len(test_frame1) - num_rows = self._count_rows("test_frame4") + num_rows = count_rows(self.conn, "test_frame4") assert num_rows == num_entries