Skip to content

Commit a2f68ef

Browse files
Ildar Almakaev and jaidisido authored
Use DataFrame column names in INSERT statement for UPSERT operation (#1317)
* Use DataFrame column names in INSERT statement for UPSERT operation
* Make column_names an optional parameter
* Fix input parameter type to Optional[List[str]]
* Fix Pylint issue. Replace unnecessary for-comprehension with list(df.columns)
* Reformat code

Co-authored-by: jaidisido <[email protected]>
1 parent 29b94d1 commit a2f68ef

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

awswrangler/redshift.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ def _upsert(
174174
schema: str,
175175
primary_keys: Optional[List[str]] = None,
176176
precombine_key: Optional[str] = None,
177+
column_names: Optional[List[str]] = None,
177178
) -> None:
178179
if not primary_keys:
179180
primary_keys = _get_primary_keys(cursor=cursor, schema=schema, table=table)
@@ -199,7 +200,11 @@ def _upsert(
199200
sql: str = f'DELETE FROM "{schema}"."{table}" USING {temp_table} WHERE {join_clause}'
200201
_logger.debug(sql)
201202
cursor.execute(sql)
202-
insert_sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
203+
if column_names:
204+
column_names_str = ",".join(column_names)
205+
insert_sql = f"INSERT INTO {schema}.{table}({column_names_str}) SELECT {column_names_str} FROM {temp_table}"
206+
else:
207+
insert_sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
203208
_logger.debug(insert_sql)
204209
cursor.execute(insert_sql)
205210
_drop_table(cursor=cursor, schema=schema, table=temp_table)
@@ -903,11 +908,12 @@ def to_sql( # pylint: disable=too-many-locals
903908
)
904909
if index:
905910
df.reset_index(level=df.index.names, inplace=True)
906-
column_placeholders: str = ", ".join(["%s"] * len(df.columns))
911+
column_names = list(df.columns)
912+
column_placeholders: str = ", ".join(["%s"] * len(column_names))
907913
schema_str = f'"{created_schema}".' if created_schema else ""
908914
insertion_columns = ""
909915
if use_column_names:
910-
insertion_columns = f"({', '.join(df.columns)})"
916+
insertion_columns = f"({', '.join(column_names)})"
911917
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
912918
df=df, column_placeholders=column_placeholders, chunksize=chunksize
913919
)
@@ -923,6 +929,7 @@ def to_sql( # pylint: disable=too-many-locals
923929
temp_table=created_table,
924930
primary_keys=primary_keys,
925931
precombine_key=precombine_key,
932+
column_names=column_names,
926933
)
927934
if commit_transaction:
928935
con.commit()

0 commit comments

Comments
 (0)