
Commit 7787f0a

Brannon Imamura authored
Feature: Add precombine key to upsert method for Redshift (#1304)
* implement precombine_key for upserts. This will prefer data from the file / tmp table when the precombine keys are equal. Also fix up some inconsistencies in the docs.
* black formatting
* Update redshift.py
* add test for precombine upsert
* Update test_redshift.py
* Update test_redshift.py
* sort imports...
* no index setting
* order and data type must be equal for comparison to work
* Iterator[DataFrame] vs DataFrame
* pandas is being tricky somewhere, switching to numpy comparison
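For context, a minimal usage sketch of the new parameter, modeled on the test added in this commit. The Glue connection name and table name below are placeholders, not part of the change:

    import pandas as pd

    import awswrangler as wr

    # Hypothetical Glue connection and table names, for illustration only.
    con = wr.redshift.connect("my-redshift-connection")
    new_rows = pd.DataFrame({"id": [1, 2], "val": [5.0, 7.0]})

    # With mode="upsert" and precombine_key set, a staged row replaces an existing
    # row with the same primary key only when its "val" is greater than or equal to
    # the target's; otherwise the existing row is kept.
    wr.redshift.to_sql(
        df=new_rows,
        con=con,
        schema="public",
        table="my_table",
        mode="upsert",
        index=False,
        primary_keys=["id"],
        precombine_key="val",
    )
    con.close()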
1 parent 425b969 commit 7787f0a

File tree: 2 files changed, +130 / -36 lines changed


awswrangler/redshift.py

Lines changed: 75 additions & 36 deletions
@@ -173,6 +173,7 @@ def _upsert(
     temp_table: str,
     schema: str,
     primary_keys: Optional[List[str]] = None,
+    precombine_key: Optional[str] = None,
 ) -> None:
     if not primary_keys:
         primary_keys = _get_primary_keys(cursor=cursor, schema=schema, table=table)
@@ -181,12 +182,26 @@ def _upsert(
         raise exceptions.InvalidRedshiftPrimaryKeys()
     equals_clause: str = f"{table}.%s = {temp_table}.%s"
     join_clause: str = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
-    sql: str = f'DELETE FROM "{schema}"."{table}" USING {temp_table} WHERE {join_clause}'
-    _logger.debug(sql)
-    cursor.execute(sql)
-    sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
-    _logger.debug(sql)
-    cursor.execute(sql)
+    if precombine_key:
+        delete_from_target_filter: str = f"AND {table}.{precombine_key} <= {temp_table}.{precombine_key}"
+        delete_from_temp_filter: str = f"AND {table}.{precombine_key} > {temp_table}.{precombine_key}"
+        target_del_sql: str = (
+            f'DELETE FROM "{schema}"."{table}" USING {temp_table} WHERE {join_clause} {delete_from_target_filter}'
+        )
+        _logger.debug(target_del_sql)
+        cursor.execute(target_del_sql)
+        source_del_sql: str = (
+            f'DELETE FROM {temp_table} USING "{schema}"."{table}" WHERE {join_clause} {delete_from_temp_filter}'
+        )
+        _logger.debug(source_del_sql)
+        cursor.execute(source_del_sql)
+    else:
+        sql: str = f'DELETE FROM "{schema}"."{table}" USING {temp_table} WHERE {join_clause}'
+        _logger.debug(sql)
+        cursor.execute(sql)
+    insert_sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
+    _logger.debug(insert_sql)
+    cursor.execute(insert_sql)
     _drop_table(cursor=cursor, schema=schema, table=temp_table)
 
 
@@ -424,29 +439,29 @@ def connect(
     ----------
     connection : Optional[str]
         Glue Catalog Connection name.
-    secret_id: Optional[str]:
+    secret_id : Optional[str]:
         Specifies the secret containing the connection details that you want to retrieve.
         You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
     catalog_id : str, optional
         The ID of the Data Catalog.
         If none is provided, the AWS account ID is used by default.
-    dbname: Optional[str]
+    dbname : Optional[str]
         Optional database name to overwrite the stored one.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    ssl: bool
+    ssl : bool
         This governs SSL encryption for TCP/IP sockets.
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
-    timeout: Optional[int]
+    timeout : Optional[int]
        This is the time in seconds before the connection to the server will time out.
        The default is None which means no timeout.
        This parameter is forward to redshift_connector.
        https://github.com/aws/amazon-redshift-python-driver
-    max_prepared_statements: int
+    max_prepared_statements : int
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
-    tcp_keepalive: bool
+    tcp_keepalive : bool
         If True then use TCP keepalive. The default is True.
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
@@ -534,19 +549,19 @@ def connect_temp(
         in addition to any group memberships for an existing user. If not specified, a new user is added only to PUBLIC.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    ssl: bool
+    ssl : bool
         This governs SSL encryption for TCP/IP sockets.
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
-    timeout: Optional[int]
+    timeout : Optional[int]
         This is the time in seconds before the connection to the server will time out.
         The default is None which means no timeout.
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
-    max_prepared_statements: int
+    max_prepared_statements : int
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
-    tcp_keepalive: bool
+    tcp_keepalive : bool
         If True then use TCP keepalive. The default is True.
         This parameter is forward to redshift_connector.
         https://github.com/aws/amazon-redshift-python-driver
@@ -697,7 +712,7 @@ def read_sql_table(
         List of parameters to pass to execute method.
         The syntax used to pass parameters is database driver dependent.
         Check your database driver documentation for which of the five syntax styles,
-        described in PEP 249s paramstyle, is supported.
+        described in PEP 249's paramstyle, is supported.
     chunksize : int, optional
         If specified, return an iterator where chunksize is the number of rows to include in each chunk.
     dtype : Dict[str, pyarrow.DataType], optional
@@ -761,6 +776,7 @@ def to_sql( # pylint: disable=too-many-locals
     lock: bool = False,
     chunksize: int = 200,
     commit_transaction: bool = True,
+    precombine_key: Optional[str] = None,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -793,7 +809,7 @@ def to_sql( # pylint: disable=too-many-locals
     index : bool
         True to store the DataFrame index as a column in the table,
         otherwise False to ignore it.
-    dtype: Dict[str, str], optional
+    dtype : Dict[str, str], optional
         Dictionary of columns names and Redshift types to be casted.
         Useful when you have columns with undetermined or mixed data types.
         (e.g. {'col name': 'VARCHAR(10)', 'col2 name': 'FLOAT'})
@@ -819,10 +835,14 @@ def to_sql( # pylint: disable=too-many-locals
         inserted into the database columns `col1` and `col3`.
     lock : bool
         True to execute LOCK command inside the transaction to force serializable isolation.
-    chunksize: int
+    chunksize : int
         Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
-    commit_transaction: bool
+    commit_transaction : bool
         Whether to commit the transaction. True by default.
+    precombine_key : str, optional
+        When there is a primary_key match during upsert, this column will change the upsert method,
+        comparing the values of the specified column from source and target, and keeping the
+        larger of the two. Will only work when mode = upsert.
 
     Returns
     -------
@@ -887,7 +907,14 @@ def to_sql( # pylint: disable=too-many-locals
             if table != created_table:  # upsert
                 if lock:
                     _lock(cursor, [table], schema=schema)
-                _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
+                _upsert(
+                    cursor=cursor,
+                    schema=schema,
+                    table=table,
+                    temp_table=created_table,
+                    primary_keys=primary_keys,
+                    precombine_key=precombine_key,
+                )
             if commit_transaction:
                 con.commit()
     except Exception as ex:
@@ -1071,7 +1098,7 @@ def unload(
 
     Parameters
     ----------
-    sql: str
+    sql : str
         SQL query.
     path : Union[str, List[str]]
         S3 path to write stage files (e.g. s3://bucket_name/any_name/)
@@ -1114,7 +1141,7 @@ def unload(
         If integer is provided, specified number is used.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    s3_additional_kwargs:
+    s3_additional_kwargs : Dict[str, str], optional
         Forward to botocore requests, only "SSECustomerAlgorithm" and "SSECustomerKey" arguments will be considered.
 
     Returns
@@ -1206,6 +1233,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
     sql_copy_extra_params: Optional[List[str]] = None,
     boto3_session: Optional[boto3.Session] = None,
     s3_additional_kwargs: Optional[Dict[str, str]] = None,
+    precombine_key: Optional[str] = None,
 ) -> None:
     """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).
 
@@ -1277,12 +1305,12 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
         Should Wrangler add SERIALIZETOJSON parameter into the COPY command?
         SERIALIZETOJSON is necessary to load nested data
         https://docs.aws.amazon.com/redshift/latest/dg/ingest-super.html#copy_json
-    path_suffix: Union[str, List[str], None]
+    path_suffix : Union[str, List[str], None]
         Suffix or List of suffixes to be scanned on s3 for the schema extraction
         (e.g. [".gz.parquet", ".snappy.parquet"]).
         Only has effect during the table creation.
         If None, will try to read all files. (default)
-    path_ignore_suffix: Union[str, List[str], None]
+    path_ignore_suffix : Union[str, List[str], None]
         Suffix or List of suffixes for S3 keys to be ignored during the schema extraction.
         (e.g. [".csv", "_SUCCESS"]).
         Only has effect during the table creation.
@@ -1293,17 +1321,21 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
         If integer is provided, specified number is used.
     lock : bool
         True to execute LOCK command inside the transaction to force serializable isolation.
-    commit_transaction: bool
+    commit_transaction : bool
         Whether to commit the transaction. True by default.
-    manifest: bool
+    manifest : bool
         If set to true path argument accepts a S3 uri to a manifest file.
-    sql_copy_extra_params: Optional[List[str]]
+    sql_copy_extra_params : Optional[List[str]]
         Additional copy parameters to pass to the command. For example: ["STATUPDATE ON"]
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    s3_additional_kwargs:
+    s3_additional_kwargs : Dict[str, str], optional
         Forwarded to botocore requests.
         e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}
+    precombine_key : str, optional
+        When there is a primary_key match during upsert, this column will change the upsert method,
+        comparing the values of the specified column from source and target, and keeping the
+        larger of the two. Will only work when mode = upsert.
 
     Returns
     -------
@@ -1374,7 +1406,14 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
             if table != created_table:  # upsert
                 if lock:
                     _lock(cursor, [table], schema=schema)
-                _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
+                _upsert(
+                    cursor=cursor,
+                    schema=schema,
+                    table=table,
+                    temp_table=created_table,
+                    primary_keys=primary_keys,
+                    precombine_key=precombine_key,
+                )
             if commit_transaction:
                 con.commit()
     except Exception as ex:
@@ -1440,7 +1479,7 @@ def copy( # pylint: disable=too-many-arguments
 
     Parameters
     ----------
-    df: pandas.DataFrame
+    df : pandas.DataFrame
         Pandas DataFrame.
     path : str
         S3 path to write stage files (e.g. s3://bucket_name/any_name/).
@@ -1462,12 +1501,12 @@ def copy( # pylint: disable=too-many-arguments
         The session key for your AWS account. This is only needed when you are using temporary credentials.
     index : bool
         True to store the DataFrame index in file, otherwise False to ignore it.
-    dtype: Dict[str, str], optional
+    dtype : Dict[str, str], optional
         Dictionary of columns names and Athena/Glue types to be casted.
         Useful when you have columns with undetermined or mixed data types.
         Only takes effect if dataset=True.
         (e.g. {'col name': 'bigint', 'col2 name': 'int'})
-    mode: str
+    mode : str
         Append, overwrite or upsert.
     overwrite_method : str
         Drop, cascade, truncate, or delete. Only applicable in overwrite mode.
@@ -1477,7 +1516,7 @@ def copy( # pylint: disable=too-many-arguments
         "truncate" - ``TRUNCATE ...`` - truncates the table, but immediatly commits current
         transaction & starts a new one, hence the overwrite happens in two transactions and is not atomic.
         "delete" - ``DELETE FROM ...`` - deletes all rows from the table. Slow relative to the other methods.
-    diststyle: str
+    diststyle : str
         Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
         https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
     distkey : str, optional
@@ -1501,11 +1540,11 @@ def copy( # pylint: disable=too-many-arguments
         If integer is provided, specified number is used.
     lock : bool
         True to execute LOCK command inside the transaction to force serializable isolation.
-    sql_copy_extra_params: Optional[List[str]]
+    sql_copy_extra_params : Optional[List[str]]
         Additional copy parameters to pass to the command. For example: ["STATUPDATE ON"]
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    s3_additional_kwargs:
+    s3_additional_kwargs : Dict[str, str], optional
         Forwarded to botocore requests.
        e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}
     max_rows_by_file : int

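For illustration, here is a sketch of the statements the precombine branch of _upsert above would build, using hypothetical names (target table public.orders, staging table orders_temp, primary key id, precombine_key "updated_at"); these names are examples, not taken from the library:

    # Hypothetical names; mirrors the string building in _upsert() when precombine_key is set.
    schema, table, temp_table = "public", "orders", "orders_temp"
    precombine_key = "updated_at"
    join_clause = f"{table}.id = {temp_table}.id"

    # 1) Delete target rows whose precombine value is less than or equal to the staged
    #    row's, so ties are resolved in favor of the incoming data.
    target_del_sql = (
        f'DELETE FROM "{schema}"."{table}" USING {temp_table} '
        f"WHERE {join_clause} AND {table}.{precombine_key} <= {temp_table}.{precombine_key}"
    )

    # 2) Delete staged rows whose precombine value is strictly smaller than the target's.
    source_del_sql = (
        f'DELETE FROM {temp_table} USING "{schema}"."{table}" '
        f"WHERE {join_clause} AND {table}.{precombine_key} > {temp_table}.{precombine_key}"
    )

    # 3) Insert whatever survives in the staging table into the target.
    insert_sql = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
    print(target_del_sql, source_del_sql, insert_sql, sep="\n")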
tests/test_redshift.py

Lines changed: 55 additions & 0 deletions
@@ -5,6 +5,7 @@
 from decimal import Decimal
 
 import boto3
+import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -697,6 +698,60 @@ def test_upsert(redshift_table, redshift_con):
     assert len(df.columns) == len(df4.columns)
 
 
+def test_upsert_precombine(redshift_table, redshift_con):
+    df = pd.DataFrame({"id": list((range(10))), "val": list([1.0 if i % 2 == 0 else 10.0 for i in range(10)])})
+    df3 = pd.DataFrame({"id": list((range(6, 14))), "val": list([10.0 if i % 2 == 0 else 1.0 for i in range(8)])})
+
+    # Do upsert in pandas
+    df_m = pd.merge(df, df3, on="id", how="outer")
+    df_m["val"] = np.where(df_m["val_y"] >= df_m["val_x"], df_m["val_y"], df_m["val_x"])
+    df_m["val"] = df_m["val"].fillna(df_m["val_y"])
+    df_m = df_m.drop(columns=["val_x", "val_y"])
+
+    # CREATE
+    wr.redshift.to_sql(
+        df=df,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="overwrite",
+        index=False,
+        primary_keys=["id"],
+    )
+    df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} order by id", con=redshift_con)
+    assert df.shape == df2.shape
+
+    # UPSERT
+    wr.redshift.to_sql(
+        df=df3,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="upsert",
+        index=False,
+        primary_keys=["id"],
+        precombine_key="val",
+    )
+    df4 = wr.redshift.read_sql_query(
+        sql=f"SELECT * FROM public.{redshift_table} order by id",
+        con=redshift_con,
+    )
+    assert np.array_equal(df_m.to_numpy(), df4.to_numpy())
+
+    # UPSERT 2
+    wr.redshift.to_sql(
+        df=df3,
+        con=redshift_con,
+        schema="public",
+        table=redshift_table,
+        mode="upsert",
+        index=False,
+        precombine_key="val",
+    )
+    df4 = wr.redshift.read_sql_query(sql=f"SELECT * FROM public.{redshift_table} order by id", con=redshift_con)
+    assert np.array_equal(df_m.to_numpy(), df4.to_numpy())
+
+
 def test_read_retry(redshift_con):
     try:
         wr.redshift.read_sql_query("ERROR", redshift_con)
