Skip to content

Commit b2e6e22

Browse files
committed
Improve decimal casting.
1 parent 8dd0376 commit b2e6e22

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

awswrangler/_data_types.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def athena_types_from_pyarrow_schema(
417417

418418
def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd.DataFrame:
419419
"""Cast columns in a Pandas DataFrame."""
420-
mutable_ensured: bool = False
420+
mutability_ensured: bool = False
421421
for col, athena_type in dtype.items():
422422
if (
423423
(col in df.columns)
@@ -429,9 +429,9 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd
429429
current_type: str = _normalize_pandas_dtype_name(dtype=str(df[col].dtypes))
430430
if desired_type != current_type: # Needs conversion
431431
_logger.debug("current_type: %s -> desired_type: %s", current_type, desired_type)
432-
if mutable_ensured is False:
432+
if mutability_ensured is False:
433433
df = _utils.ensure_df_is_mutable(df=df)
434-
mutable_ensured = True
434+
mutability_ensured = True
435435
_cast_pandas_column(df=df, col=col, current_type=current_type, desired_type=desired_type)
436436

437437
return df
@@ -453,11 +453,10 @@ def _cast_pandas_column(df: pd.DataFrame, col: str, current_type: str, desired_t
453453
elif desired_type == "bytes":
454454
df[col] = df[col].astype("string").str.encode(encoding="utf-8").replace(to_replace={pd.NA: None})
455455
elif desired_type == "decimal":
456-
df[col] = (
457-
df[col]
458-
.astype("str")
459-
.apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
460-
)
456+
# First cast to string
457+
df = _cast_pandas_column(df=df, col=col, current_type=current_type, desired_type="string")
458+
# Then cast to decimal
459+
df[col] = df[col].apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", "None", " ", "<NA>") else None)
461460
elif desired_type == "string":
462461
if current_type.lower().startswith("int") is True:
463462
df[col] = df[col].astype(str).astype("string")

tests/test_db.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import random
33
import string
4+
from decimal import Decimal
45

56
import boto3
67
import pandas as pd
@@ -114,8 +115,7 @@ def test_postgresql_param():
114115
assert df["col0"].iloc[0] == 1
115116

116117

117-
def test_redshift_copy_unload(bucket, databases_parameters):
118-
path = f"s3://{bucket}/test_redshift_copy/"
118+
def test_redshift_copy_unload(path, databases_parameters):
119119
df = get_df().drop(["iint8", "binary"], axis=1, inplace=False)
120120
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
121121
wr.db.copy_to_redshift(
@@ -258,10 +258,9 @@ def test_redshift_copy_upsert(bucket, databases_parameters):
258258
(None, None, wr.exceptions.InvalidRedshiftSortstyle, "foo", ["id"]),
259259
],
260260
)
261-
def test_redshift_exceptions(bucket, databases_parameters, diststyle, distkey, sortstyle, sortkey, exc):
261+
def test_redshift_exceptions(path, databases_parameters, diststyle, distkey, sortstyle, sortkey, exc):
262262
df = pd.DataFrame({"id": [1], "name": "joe"})
263263
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
264-
path = f"s3://{bucket}/test_redshift_exceptions_{random.randint(0, 1_000_000)}/"
265264
with pytest.raises(exc):
266265
wr.db.copy_to_redshift(
267266
df=df,
@@ -280,9 +279,8 @@ def test_redshift_exceptions(bucket, databases_parameters, diststyle, distkey, s
280279
wr.s3.delete_objects(path=path)
281280

282281

283-
def test_redshift_spectrum(bucket, glue_database, redshift_external_schema):
282+
def test_redshift_spectrum(path, glue_database, redshift_external_schema):
284283
df = pd.DataFrame({"id": [1, 2, 3, 4, 5], "col_str": ["foo", None, "bar", None, "xoo"], "par_int": [0, 1, 0, 1, 1]})
285-
path = f"s3://{bucket}/test_redshift_spectrum/"
286284
paths = wr.s3.to_parquet(
287285
df=df,
288286
path=path,
@@ -305,8 +303,7 @@ def test_redshift_spectrum(bucket, glue_database, redshift_external_schema):
305303
assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_redshift_spectrum") is True
306304

307305

308-
def test_redshift_category(bucket, databases_parameters):
309-
path = f"s3://{bucket}/test_redshift_category/"
306+
def test_redshift_category(path, databases_parameters):
310307
df = get_df_category().drop(["binary"], axis=1, inplace=False)
311308
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
312309
wr.db.copy_to_redshift(
@@ -341,10 +338,9 @@ def test_redshift_category(bucket, databases_parameters):
341338
wr.s3.delete_objects(path=path)
342339

343340

344-
def test_redshift_unload_extras(bucket, databases_parameters, kms_key_id):
341+
def test_redshift_unload_extras(bucket, path, databases_parameters, kms_key_id):
345342
table = "test_redshift_unload_extras"
346343
schema = databases_parameters["redshift"]["schema"]
347-
path = f"s3://{bucket}/{table}/"
348344
wr.s3.delete_objects(path=path)
349345
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
350346
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
@@ -529,3 +525,52 @@ def test_redshift_copy_unload_long_string(path, databases_parameters):
529525
)
530526
assert len(df2.index) == 2
531527
assert len(df2.columns) == 2
528+
529+
530+
def test_spectrum_decimal_cast(path, path2, glue_table, glue_database, redshift_external_schema, databases_parameters):
531+
df = pd.DataFrame(
532+
{"c0": [1, 2], "c1": [1, None], "c2": [2.22222, None], "c3": ["3.33333", None], "c4": [None, None]}
533+
)
534+
paths = wr.s3.to_parquet(
535+
df=df,
536+
path=path,
537+
database=glue_database,
538+
table=glue_table,
539+
dataset=True,
540+
dtype={"c1": "decimal(11,5)", "c2": "decimal(11,5)", "c3": "decimal(11,5)", "c4": "decimal(11,5)"},
541+
)["paths"]
542+
wr.s3.wait_objects_exist(paths=paths, use_threads=False)
543+
544+
# Athena
545+
df2 = wr.athena.read_sql_table(table=glue_table, database=glue_database)
546+
assert df2.shape == (2, 5)
547+
df2 = df2.drop(df2[df2.c0 == 2].index)
548+
assert df2.c1[0] == Decimal((0, (1, 0, 0, 0, 0, 0), -5))
549+
assert df2.c2[0] == Decimal((0, (2, 2, 2, 2, 2, 2), -5))
550+
assert df2.c3[0] == Decimal((0, (3, 3, 3, 3, 3, 3), -5))
551+
assert df2.c4[0] is None
552+
553+
# Redshift Spectrum
554+
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
555+
df2 = wr.db.read_sql_table(table=glue_table, schema=redshift_external_schema, con=engine)
556+
assert df2.shape == (2, 5)
557+
df2 = df2.drop(df2[df2.c0 == 2].index)
558+
assert df2.c1[0] == Decimal((0, (1, 0, 0, 0, 0, 0), -5))
559+
assert df2.c2[0] == Decimal((0, (2, 2, 2, 2, 2, 2), -5))
560+
assert df2.c3[0] == Decimal((0, (3, 3, 3, 3, 3, 3), -5))
561+
assert df2.c4[0] is None
562+
563+
# Redshift Spectrum Unload
564+
engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
565+
df2 = wr.db.unload_redshift(
566+
sql=f"SELECT * FROM {redshift_external_schema}.{glue_table}",
567+
con=engine,
568+
iam_role=databases_parameters["redshift"]["role"],
569+
path=path2,
570+
)
571+
assert df2.shape == (2, 5)
572+
df2 = df2.drop(df2[df2.c0 == 2].index)
573+
assert df2.c1[0] == Decimal((0, (1, 0, 0, 0, 0, 0), -5))
574+
assert df2.c2[0] == Decimal((0, (2, 2, 2, 2, 2, 2), -5))
575+
assert df2.c3[0] == Decimal((0, (3, 3, 3, 3, 3, 3), -5))
576+
assert df2.c4[0] is None

0 commit comments

Comments
 (0)