diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index b235f120d98c8..813012a835d5c 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -2819,12 +2819,15 @@ def to_sql(
         schema : str, optional
             Specify the schema (if database flavor supports this). If None, use
             default schema.
         if_exists : {'fail', 'replace', 'append'}, default 'fail'
             How to behave if the table already exists.

             * fail: Raise a ValueError.
             * replace: Drop the table before inserting new values.
             * append: Insert new values to the existing table.
+        on_conflict : {None, 'do_nothing', 'do_update'}, optional
+            How to handle rows whose primary keys clash with rows already in
+            the table. Only valid with ``if_exists='append'``.
+
+            * None: Do not handle clashes; the database raises an error.
+            * do_nothing: Keep the existing rows and insert only the rows
+              with non-conflicting primary keys.
+            * do_update: Update the clashing rows in the database with the
+              incoming data, and insert the remaining rows.

         index : bool, default True
             Write DataFrame index as a column. Uses `index_label` as the column
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index ec5262ee3a04c..62f781d9a9ff6 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -602,6 +602,7 @@ def to_sql(
     con,
     schema: str | None = None,
     if_exists: str = "fail",
+    on_conflict: str | None = None,
     index: bool = True,
     index_label=None,
     chunksize: int | None = None,
@@ -626,10 +627,18 @@ def to_sql(
     schema : str, optional
         Name of SQL schema in database to write to (if database flavor
         supports this). If None, use default schema (default).
     if_exists : {'fail', 'replace', 'append'}, default 'fail'
         - fail: If table exists, do nothing.
         - replace: If table exists, drop it, recreate it, and insert data.
         - append: If table exists, insert data. Create if does not exist.
+    on_conflict : {None, 'do_nothing', 'do_update'}, optional
+        Determine insertion behaviour in the case of a primary key clash.
+        - None: Do not handle primary key clashes; the database will raise
+          an error.
+        - 'do_nothing': Ignore incoming rows whose primary keys clash with
+          existing rows, and insert only the rows with non-conflicting keys.
+        - 'do_update': Update the existing rows whose primary keys clash,
+          and append the remaining rows.
     index : bool, default True
         Write DataFrame index as a column.
    index_label : str or sequence, optional
@@ -663,9 +672,21 @@ def to_sql(
    **engine_kwargs
        Any additional kwargs are passed to the engine.
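+
+    Examples
+    --------
+    A minimal sketch of an upsert-style append. ``conn`` and the target table
+    ``existing_table`` (primary key column ``a``) are assumed to already
+    exist:
+
+    >>> df = pd.DataFrame({"a": [1, 4], "b": ["x", "y"]})  # doctest: +SKIP
+    >>> to_sql(df, "existing_table", conn, if_exists="append",
+    ...        on_conflict="do_update", index=False)  # doctest: +SKIP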
""" - if if_exists not in ("fail", "replace", "append"): + if if_exists not in ( + "fail", + "replace", + "append", + ): raise ValueError(f"'{if_exists}' is not valid for if_exists") + if on_conflict: + # on_conflict argument is valid + if on_conflict not in ("do_update", "do_nothing"): + raise ValueError(f"'{on_conflict}' is not valid for on_conflict'") + # on_conflict only used with append + elif if_exists != "append": + raise ValueError("on_conflict can only be used with 'append' operations") + pandas_sql = pandasSQL_builder(con, schema=schema) if isinstance(frame, Series): @@ -679,6 +700,7 @@ def to_sql( frame, name, if_exists=if_exists, + on_conflict=on_conflict, index=index, index_label=index_label, schema=schema, @@ -759,6 +781,7 @@ def __init__( frame=None, index=True, if_exists="fail", + on_conflict=None, prefix="pandas", index_label=None, schema=None, @@ -772,6 +795,7 @@ def __init__( self.index = self._index_name(index, index_label) self.schema = schema self.if_exists = if_exists + self.on_conflict = on_conflict self.keys = keys self.dtype = dtype @@ -815,6 +839,184 @@ def create(self): else: self._execute_create() + def _load_existing_pkeys(self, primary_keys, primary_key_values): + """ + Load existing primary keys from Database + + Parameters + ---------- + primary_keys : list of str + List of primary key column names + primary_key_values : list of str + List of primary key values already present in incoming dataframe + + Returns + ------- + list of str + primary key values in incoming dataframe which already exist in database + """ + from sqlalchemy import ( + and_, + select, + ) + + cols_to_fetch = [self.table.c[key] for key in primary_keys] + # select_stmt = select(cols_to_fetch).where( + # tuple_(*cols_to_fetch).in_(primary_key_values) + # ) + select_stmt = select(cols_to_fetch).where( + and_( + col.in_(key[i] for key in primary_key_values) + for i, col in enumerate(cols_to_fetch) + ) + ) + return self.pd_sql.execute(select_stmt).fetchall() + + def _split_incoming_data(self, primary_keys, keys_in_db): + """ + Split incoming dataframe based off whether primary key already exists in db. + + Parameters + ---------- + primary_keys : list of str + Primary keys columns + keys_in_db : list of str + Primary key values which already exist in database table + + Returns + ------- + tuple of DataFrame, DataFrame + DataFrame of rows with duplicate pkey, DataFrame of rows with new pkey + """ + from pandas.core.indexes.multi import MultiIndex + + in_db = _wrap_result(data=keys_in_db, columns=primary_keys) + # Get temporary dataframe so as not to delete values from main df + temp = self._get_index_formatted_dataframe() + # Create multi-indexes for membership lookup + in_db_idx = MultiIndex.from_arrays([in_db[col] for col in primary_keys]) + tmp_idx = MultiIndex.from_arrays([temp[col] for col in primary_keys]) + exists_mask = tmp_idx.isin(in_db_idx) + return temp.loc[exists_mask], temp.loc[~exists_mask] + + def _generate_update_statements(self, primary_keys, keys_in_db, rows_to_update): + """ + Generate SQL Update statements for rows with existing primary keys + + Currently, SQL Update statements do not support a multi-statement query, + therefore this method returns a list of individual update queries which + will need to be executed in one transaction. 
+
+        Parameters
+        ----------
+        primary_keys : list of str
+            Primary key columns
+        rows_to_update : DataFrame
+            Rows whose primary keys already exist in the table, carrying the
+            new values to write
+
+        Returns
+        -------
+        list of sqlalchemy.sql.dml.Update
+            List of update queries
+        """
+        from sqlalchemy import and_

+        new_records = rows_to_update.to_dict(orient="records")
+        pk_cols = [self.table.c[key] for key in primary_keys]
+
+        # Key each UPDATE on the record's own primary key values, so the row
+        # order returned by the database cannot mis-align the pairing
+        stmts = []
+        for record in new_records:
+            stmt = (
+                self.table.update()
+                .where(and_(col == record[col.name] for col in pk_cols))
+                .values(record)
+            )
+            stmts.append(stmt)
+        return stmts
+
+    def _on_conflict_do_update(self):
+        """
+        Generate update statements for rows whose primary keys clash with the db.
+
+        ``on_conflict='do_update'`` prioritizes incoming data over existing
+        data in the DB. This method splits the incoming dataframe between rows
+        with new and existing primary key values. Update statements are
+        generated for the existing values, while the new values are passed on
+        to be inserted as usual.
+
+        The updates are executed in the same transaction as the ensuing data
+        insert.
+
+        Returns
+        -------
+        new_data : DataFrame
+            Incoming rows with non-conflicting primary keys
+        update_stmts : list of sqlalchemy.sql.dml.Update
+            Update statements to be executed against the DB
+        """
+        # Primary key data
+        pk_cols, pk_values = self._get_primary_key_data()
+        existing_keys = self._load_existing_pkeys(pk_cols, pk_values)
+        existing_data, new_data = self._split_incoming_data(pk_cols, existing_keys)
+        update_stmts = self._generate_update_statements(pk_cols, existing_data)
+
+        return new_data, update_stmts
+
+    def _on_conflict_do_nothing(self):
+        """
+        Split incoming dataframe so that only rows with new primary keys are inserted.
+
+        ``on_conflict='do_nothing'`` prioritizes existing data in the DB. This
+        method identifies incoming records whose primary key values correspond
+        to existing primary key constraints in the db table, and excludes them
+        from the insert.
+
+        Returns
+        -------
+        DataFrame
+            Incoming rows with non-conflicting primary keys
+        """
+        pk_cols, pk_values = self._get_primary_key_data()
+        existing_keys = self._load_existing_pkeys(pk_cols, pk_values)
+        _, new_data = self._split_incoming_data(pk_cols, existing_keys)
+        return new_data
+
+    def _get_primary_key_data(self):
+        """
+        Get primary keys from the database and return the matching dataframe values.
+
+        Upsert workflows require knowledge of what is already in the database.
+        This method reflects the meta object to get a list of primary keys,
+        and then returns the values of the incoming dataframe columns whose
+        names match those keys.
+
+        Returns
+        -------
+        primary_keys : list of str
+            Primary key names
+        primary_key_values : list of tuple
+            DataFrame rows, for the columns corresponding to `primary_keys`
+        """
+        # reflect MetaData object and assign contents of db to self.table attribute
+        bind = None
+        if not self.pd_sql.meta.is_bound():
+            bind = self.pd_sql.connectable
+        self.pd_sql.meta.reflect(bind=bind, only=[self.name], views=True)
+        self.table = self.pd_sql.get_table(table_name=self.name, schema=self.schema)
+
+        primary_keys = self.table.primary_key.columns.keys()
+
+        # For the time being this method is defensive and raises when no
+        # pkeys are found; it could instead fall back to a normal insert
+        if len(primary_keys) == 0:
+            raise ValueError(f"No primary keys found for table {self.name}")
+
+        temp = self._get_index_formatted_dataframe()
+        primary_key_values = list(zip(*(temp[key] for key in primary_keys)))
+        return primary_keys, primary_key_values
+
     def _execute_insert(self, conn, keys: list[str], data_iter):
         """
         Execute SQL statement inserting data
@@ -845,22 +1047,36 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter):
         stmt = insert(self.table).values(data)
         conn.execute(stmt)

-    def insert_data(self):
+    def _get_index_formatted_dataframe(self):
+        """
+        Format the index of the incoming dataframe to align with a database table.
+
+        Copy the original dataframe, and check whether the dataframe index is
+        to be added to the database table. If it is, reset the index so that
+        it becomes a normal column; otherwise return the copy unchanged.
+
+        Returns
+        -------
+        DataFrame
+        """
+        # Originally the first step of insert_data; factored out here so the
+        # upsert paths can reuse it and to keep the code DRY
+        temp = self.frame.copy()
         if self.index is not None:
-            temp = self.frame.copy()
             temp.index.names = self.index
             try:
                 temp.reset_index(inplace=True)
             except ValueError as err:
                 raise ValueError(f"duplicate name in index/columns: {err}") from err
-        else:
-            temp = self.frame

-        column_names = list(map(str, temp.columns))
+        return temp
+
+    @staticmethod
+    def insert_data(data):
+        """Convert a dataframe into column names and a list of value arrays."""
+        column_names = list(map(str, data.columns))
         ncols = len(column_names)
         data_list = [None] * ncols
-
-        for i, (_, ser) in enumerate(temp.items()):
+        for i, (_, ser) in enumerate(data.items()):
             vals = ser._values
             if vals.dtype.kind == "M":
                 d = vals.to_pydatetime()
@@ -884,7 +1100,24 @@ def insert_data(self):
         return column_names, data_list

     def insert(self, chunksize: int | None = None, method: str | None = None):
+        """
+        Determine what data to pass to the underlying insert method.
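+
+        Dispatches on ``self.on_conflict``: ``'do_update'`` executes the
+        generated UPDATE statements for clashing keys and then inserts the
+        remaining rows, ``'do_nothing'`` inserts only the rows with new keys,
+        and otherwise all rows are inserted unchanged.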
+ """ + if self.on_conflict == "do_update": + new_data, update_stmts = self._on_conflict_do_update() + self._insert( + data=new_data, + chunksize=chunksize, + method=method, + other_stmts=update_stmts, + ) + elif self.on_conflict == "do_nothing": + new_data = self._on_conflict_do_nothing() + self._insert(data=new_data, chunksize=chunksize, method=method) + else: + self._insert(chunksize=chunksize, method=method) + def _insert(self, data=None, chunksize=None, method=None, other_stmts=[]): # set insert method if method is None: exec_insert = self._execute_insert @@ -895,9 +1128,12 @@ def insert(self, chunksize: int | None = None, method: str | None = None): else: raise ValueError(f"Invalid parameter `method`: {method}") - keys, data_list = self.insert_data() + if data is None: + data = self._get_index_formatted_dataframe() - nrows = len(self.frame) + keys, data_list = self.insert_data(data=data) + + nrows = len(data) if nrows == 0: return @@ -910,6 +1146,10 @@ def insert(self, chunksize: int | None = None, method: str | None = None): chunks = (nrows // chunksize) + 1 with self.pd_sql.run_transaction() as conn: + if len(other_stmts) > 0: + for stmt in other_stmts: + conn.execute(stmt) + for i in range(chunks): start_i = i * chunksize end_i = min((i + 1) * chunksize, nrows) @@ -1234,6 +1474,7 @@ def to_sql( frame, name, if_exists="fail", + on_conflict=None, index=True, index_label=None, schema=None, @@ -1555,6 +1796,7 @@ def prep_table( frame, name, if_exists="fail", + on_conflict=None, index=True, index_label=None, schema=None, @@ -1590,6 +1832,7 @@ def prep_table( frame=frame, index=index, if_exists=if_exists, + on_conflict=on_conflict, index_label=index_label, schema=schema, dtype=dtype, @@ -1636,6 +1879,7 @@ def to_sql( frame, name, if_exists="fail", + on_conflict: str | None = None, index=True, index_label=None, schema=None, @@ -1653,10 +1897,18 @@ def to_sql( frame : DataFrame name : string Name of SQL table. - if_exists : {'fail', 'replace', 'append'}, default 'fail' + if_exists : {'fail', 'replace', 'append'}, + default 'fail'. - fail: If table exists, do nothing. - replace: If table exists, drop it, recreate it, and insert data. - append: If table exists, insert data. Create if does not exist. + on_conflict : {None, 'do_nothing', 'do_update'}, optional + Determine insertion behaviour in case of a primary key clash. + - None: Do nothing to handle primary key clashes, will raise an Error. + - 'do_nothing': Ignore incoming rows with primary key clashes, and + insert only the incoming rows with non-conflicting primary keys + - 'do_update': Update existing rows in database with primary key clashes, + and append the remaining rows with non-conflicting primary keys index : boolean, default True Write DataFrame index as a column. index_label : string or sequence, default None @@ -1699,6 +1951,7 @@ def to_sql( frame=frame, name=name, if_exists=if_exists, + on_conflict=on_conflict, index=index, index_label=index_label, schema=schema, @@ -2094,6 +2347,7 @@ def to_sql( frame, name, if_exists="fail", + on_conflict: str | None = None, index=True, index_label=None, schema=None, @@ -2114,6 +2368,13 @@ def to_sql( fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if it does not exist. + on_conflict : {None, 'do_nothing', 'do_update'}, optional + Determine insertion behaviour in case of a primary key clash. + - None: Do nothing to handle primary key clashes, will raise an Error. 
+            - 'do_nothing': Ignore incoming rows whose primary keys clash,
+              and insert only the rows with non-conflicting primary keys.
+            - 'do_update': Update the existing rows in the database whose
+              primary keys clash, and append the remaining rows.
         index : bool, default True
             Write DataFrame index as a column
         index_label : string or sequence, default None
diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py
index eb3097618e158..de2267236b84a 100644
--- a/pandas/tests/io/test_sql.py
+++ b/pandas/tests/io/test_sql.py
@@ -93,9 +93,115 @@
         "mysql": "SELECT * FROM iris WHERE `Name` LIKE '%'",
         "postgresql": "SELECT * FROM iris WHERE \"Name\" LIKE '%'",
     },
+    "read_pkey_table": {
+        "pkey_table_single": {
+            "sqlite": """SELECT c FROM pkey_table_single WHERE a IN (?, ?)""",
+            "mysql": """SELECT c FROM pkey_table_single WHERE a IN (%s, %s)""",
+            "postgresql": """SELECT c FROM pkey_table_single WHERE a IN (%s, %s)""",
+        },
+        "pkey_table_comp": {
+            "sqlite": """SELECT c FROM pkey_table_comp WHERE a IN (?, ?)""",
+            "mysql": """SELECT c FROM pkey_table_comp WHERE a IN (%s, %s)""",
+            "postgresql": """SELECT c FROM pkey_table_comp WHERE a IN (%s, %s)""",
+        },
+    },
 }


+def pkey_single_table_metadata():
+    from sqlalchemy import (
+        Column,
+        Integer,
+        MetaData,
+        String,
+        Table,
+    )
+
+    metadata = MetaData()
+    pkeys = Table(
+        "pkey_table_single",
+        metadata,
+        Column("a", Integer, primary_key=True),
+        Column("b", String(200)),
+        Column("c", String(200)),
+    )
+    return pkeys
+
+
+def pkey_comp_table_metadata():
+    from sqlalchemy import (
+        Column,
+        Integer,
+        MetaData,
+        String,
+        Table,
+    )
+
+    metadata = MetaData()
+    pkeys = Table(
+        "pkey_table_comp",
+        metadata,
+        Column("a", Integer, primary_key=True),
+        Column("b", String(200), primary_key=True),
+        Column("c", String(200)),
+    )
+    return pkeys
+
+
+def create_and_load_pkey(conn):
+    from sqlalchemy import insert
+    from sqlalchemy.engine import Engine
+
+    pkey_single = pkey_single_table_metadata()
+    pkey_comp = pkey_comp_table_metadata()
+
+    pkey_single.drop(conn, checkfirst=True)
+    pkey_single.create(bind=conn)
+    pkey_comp.drop(conn, checkfirst=True)
+    pkey_comp.create(bind=conn)
+
+    headers = ["a", "b", "c"]
+    data = [(1, "name1", "val1"), (2, "name2", "val2"), (3, "name3", "val3")]
+    params = [{key: value for key, value in zip(headers, row)} for row in data]
+
+    stmt_single = insert(pkey_single).values(params)
+    stmt_comp = insert(pkey_comp).values(params)
+
+    if isinstance(conn, Engine):
+        with conn.connect() as conn:
+            with conn.begin():
+                conn.execute(stmt_single)
+                conn.execute(stmt_comp)
+    else:
+        conn.execute(stmt_single)
+        conn.execute(stmt_comp)
+
+
+def create_and_load_pkey_sqlite3(conn: sqlite3.Connection):
+    cur = conn.cursor()
+    stmt_single = """
+    CREATE TABLE pkey_table_single (
+        "a" INTEGER PRIMARY KEY,
+        "b" TEXT,
+        "c" TEXT
+    )
+    """
+    stmt_comp = """
+    CREATE TABLE pkey_table_comp (
+        "a" INTEGER,
+        "b" TEXT,
+        "c" TEXT,
+        PRIMARY KEY ("a", "b")
+    )
+    """
+    cur.execute(stmt_single)
+    cur.execute(stmt_comp)
+    data = [(1, "name1", "val1"), (2, "name2", "val2"), (3, "name3", "val3")]
+    for tbl in ["pkey_table_single", "pkey_table_comp"]:
+        stmt = f"INSERT INTO {tbl} VALUES (?, ?, ?)"
+        cur.executemany(stmt, data)
+
+
 def iris_table_metadata(dialect: str):
     from sqlalchemy import (
         REAL,
@@ -266,6 +372,31 @@ def count_rows(conn, table_name: str):
     return result.fetchone()[0]


+def read_pkeys_from_database(conn, tbl_name: str, duplicate_keys: list[int]):
+    if isinstance(conn, sqlite3.Connection):
+        # build one "?" placeholder per key rather than hard-coding two
+        placeholders = ", ".join("?" for _ in duplicate_keys)
f"""SELECT c FROM {tbl_name} WHERE A IN (?, ?)""" + cur = conn.cursor() + result = cur.execute(stmt, duplicate_keys) + else: + from sqlalchemy import ( + MetaData, + Table, + select, + ) + from sqlalchemy.engine import Engine + + meta = MetaData() + tbl = Table(tbl_name, meta, autoload_with=conn) + stmt = select([tbl.c.c]).where(tbl.c.a.in_(duplicate_keys)) + + if isinstance(conn, Engine): + with conn.connect() as conn: + result = conn.execute(stmt) + else: + result = conn.execute(stmt) + return sorted(val[0] for val in result.fetchall()) + + @pytest.fixture def iris_path(datapath): iris_path = datapath("io", "data", "csv", "iris.csv") @@ -367,6 +498,18 @@ def test_frame3(): return DataFrame(data, columns=columns) +@pytest.fixture +def pkey_frame(): + columns = ["a", "b", "c"] + data = [ + (1, "name1", "new_val1"), + (2, "name2", "new_val2"), + (4, "name4", "val4"), + (5, "name5", "val5"), + ] + return DataFrame(data, columns=columns) + + class MixInBase: def teardown_method(self, method): # if setup fails, there may not be a connection to close. @@ -454,6 +597,17 @@ def load_types_data(self, types_data): else: create_and_load_types(self.conn, types_data, self.flavor) + @pytest.fixture + def load_pkey_data(self): + if not hasattr(self, "conn"): + self.setup_connect() + self.drop_table("pkey_table_single") + self.drop_table("pkey_table_comp") + if isinstance(self.conn, sqlite3.Connection): + create_and_load_pkey_sqlite3(self.conn) + else: + create_and_load_pkey(self.conn) + def _check_iris_loaded_frame(self, iris_frame): pytype = iris_frame.dtypes[0].type row = iris_frame.iloc[0] @@ -561,6 +715,173 @@ def sample(pd_table, conn, keys, data_iter): # Nuke table self.drop_table("test_frame1") + def _to_sql_on_conflict_update(self, method, tbl_name, pkey_frame): + """ + GIVEN: + - Original database table: 3 rows + - new dataframe: 4 rows (2 duplicate keys) + WHEN: + - on conflict update insert + THEN: + - DB table len = 5 + - Conflicting primary keys in DB updated + """ + # Original table exists and as 3 rows + assert self.pandasSQL.has_table(tbl_name) + assert count_rows(self.conn, tbl_name) == 3 + # Insert new dataframe + self.pandasSQL.to_sql( + pkey_frame, + tbl_name, + if_exists="append", + on_conflict="do_update", + index=False, + method=method, + ) + # Check table len correct + assert count_rows(self.conn, tbl_name) == 5 + # Check conflicting primary keys have been updated + # Get new values for conflicting keys + data_from_db = read_pkeys_from_database(self.conn, tbl_name, [1, 2]) + # duplicate_keys = [1, 2] + # duplicate_key_query = SQL_STRINGS["read_pkey_table"][self.flavor] + # duplicate_val = self._get_exec().execute(duplicate_key_query, duplicate_keys) + # data_from_db = sorted(val[0] for val in duplicate_val) + + # Expected values from pkey_table_frame + expected = sorted(["new_val1", "new_val2"]) + assert data_from_db == expected + # Finally, confirm that duplicate values are not removed from original df object + assert len(pkey_frame.index) == 4 + # Clean up + self.drop_table(tbl_name) + + def _to_sql_on_conflict_nothing(self, method, tbl_name, pkey_frame): + """ + GIVEN: + - Original table: 3 rows + - new dataframe: 4 rows (2 duplicate keys) + WHEN: + - on conflict do nothing insert + THEN: + - database table len = 5 + - conflicting keys in table not updated + """ + # Original table exists and has 3 rows + assert self.pandasSQL.has_table(tbl_name) + assert count_rows(self.conn, tbl_name) == 3 + # Prepare SQL for reading duplicate keys + # duplicate_keys = [1, 2] + # 
duplicate_key_query = SQL_STRINGS["read_pkey_table"][self.flavor] + #  get conflicting pkey values before insert + # duplicate_val_before = self._get_exec().execute( + # duplicate_key_query, duplicate_keys + # ) + # data_from_db_before = sorted(val[0] for val in duplicate_val_before) + duplicate_keys = [1, 2] + data_from_db_before = read_pkeys_from_database( + self.conn, tbl_name, duplicate_keys + ) + # Insert new dataframe + self.pandasSQL.to_sql( + pkey_frame, + tbl_name, + if_exists="append", + on_conflict="do_nothing", + index=False, + method=method, + ) + # Check table len correct + assert count_rows(self.conn, tbl_name) == 5 + # Get conflicting keys from DB after to_sql + # duplicate_val_after = self._get_exec().execute( + # duplicate_key_query, duplicate_keys + # ) + # data_from_db_after = sorted(val[0] for val in duplicate_val_after) + data_from_db_after = read_pkeys_from_database( + self.conn, tbl_name, duplicate_keys + ) + # Get data from incoming df + data_from_df = sorted( + pkey_frame.loc[pkey_frame["a"].isin(duplicate_keys), "c"].tolist() + ) + # Check original DB values maintained for duplicate keys + assert data_from_db_before == data_from_db_after + # Check DB values not equal to new values + assert data_from_db_after != data_from_df + # Clean up + self.drop_table(tbl_name) + + def _test_to_sql_on_conflict_with_index(self, method, tbl_name, pkey_frame): + """ + GIVEN: + - Original db table: 3 rows + - New dataframe: 4 rows (2 duplicate keys), pkey as index + WHEN: + - inserting new data, noting the index column + - on conflict do update + THEN: + - DB table len = 5 + - Conflicting primary keys in DB updated + """ + # Original table exists and as 3 rows + assert self.pandasSQL.has_table(tbl_name) + assert count_rows(self.conn, tbl_name) == 3 + if tbl_name == "pkey_table_single": + index_pkey_table = pkey_frame.set_index("a") + else: + index_pkey_table = pkey_frame.set_index(["a", "b"]) + # Insert new dataframe + self.pandasSQL.to_sql( + index_pkey_table, + tbl_name, + if_exists="append", + on_conflict="do_update", + index=True, + method=method, + ) + # Check table len correct + assert count_rows(self.conn, tbl_name) == 5 + # Check conflicting primary keys have been updated + # Get new values for conflicting keys + # duplicate_keys = [1, 2] + # duplicate_key_query = SQL_STRINGS["read_pkey_table"][self.flavor] + # duplicate_val = self._get_exec().execute(duplicate_key_query, duplicate_keys) + # data_from_db = sorted(val[0] for val in duplicate_val) + data_from_db = read_pkeys_from_database(self.conn, tbl_name, [1, 2]) + # Expected values from pkey_table_frame + expected = sorted(["new_val1", "new_val2"]) + assert data_from_db == expected + # Finally, confirm that duplicate values are not removed from original df object + assert len(pkey_frame.index) == 4 + # Clean up + self.drop_table(tbl_name) + + def _to_sql_on_conflict_with_non_append(self, if_exists, on_conflict, pkey_frame): + """ + GIVEN: + - to_sql is called + WHEN: + - `on_conflict` is not null + - `if_exists` is set to a value other than `append` + THEN: + - ValueError is raised + """ + # Attempt insert + assert if_exists != "append" + with pytest.raises( + ValueError, match="on_conflict can only be used with 'append' operations" + ): + # Insert new dataframe + sql.to_sql( + pkey_frame, + "some_table", + con=self.conn, + if_exists=if_exists, + on_conflict=on_conflict, + index=False, + ) + def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs): """`to_sql` with the `engine` param""" # mostly 
@@ -669,7 +990,7 @@ def setup_connect(self):
         self.conn = self.connect()

     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         self.load_test_data_and_sql()

     def load_test_data_and_sql(self):
@@ -736,6 +1057,28 @@ def test_to_sql_series(self):
         s2 = sql.read_sql_query("SELECT * FROM test_series", self.conn)
         tm.assert_frame_equal(s.to_frame(), s2)

+    def test_to_sql_invalid_on_conflict(self, pkey_frame):
+        msg = "'update' is not valid for on_conflict"
+        with pytest.raises(ValueError, match=msg):
+            sql.to_sql(
+                pkey_frame,
+                "pkey_frame1",
+                self.conn,
+                if_exists="append",
+                on_conflict="update",
+            )
+
+    def test_to_sql_on_conflict_non_append(self, pkey_frame):
+        msg = "on_conflict can only be used with 'append' operations"
+        with pytest.raises(ValueError, match=msg):
+            sql.to_sql(
+                pkey_frame,
+                "pkey_frame1",
+                self.conn,
+                if_exists="replace",
+                on_conflict="do_update",
+            )
+
     def test_roundtrip(self, test_frame1):
         sql.to_sql(test_frame1, "test_frame_roundtrip", con=self.conn)
         result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=self.conn)
@@ -1300,7 +1643,7 @@ class _EngineToConnMixin:
     """

     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         super().load_test_data_and_sql()
         engine = self.conn
         conn = engine.connect()
@@ -1423,7 +1766,7 @@ def load_test_data_and_sql(self):
         pass

     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         pass

     @classmethod
@@ -1479,6 +1822,26 @@ def test_to_sql_method_multi(self, test_frame1):
     def test_to_sql_method_callable(self, test_frame1):
         self._to_sql_method_callable(test_frame1)

+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_conflict_nothing(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_conflict_nothing(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_conflict_update(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_conflict_update(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_on_conflict_with_index(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_conflict_with_index(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("if_exists", ["fail", "replace"])
+    @pytest.mark.parametrize("on_conflict", ["do_update", "do_nothing"])
+    def test_to_sql_conflict_with_non_append(self, if_exists, on_conflict, pkey_frame):
+        self._to_sql_on_conflict_with_non_append(if_exists, on_conflict, pkey_frame)
+
     def test_create_table(self):
         temp_conn = self.connect()
         temp_frame = DataFrame(