Commit 2306e4f

fix: Schema evolution for to_csv and to_json (#2104)
1 parent 3c5f236 commit 2306e4f

3 files changed, 167 insertions(+), 19 deletions(-)
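
The added tests below show the behavior this commit fixes: appending a frame that carries new columns to an existing Glue-cataloged CSV or JSON dataset with schema_evolution=True now updates the catalog table, while a data type change (or, for CSV, a column-order change) still raises InvalidArgumentValue. A minimal usage sketch mirroring those tests; the S3 path, database, and table names are placeholders:

import pandas as pd
import awswrangler as wr

# Placeholders -- substitute a real S3 prefix, Glue database, and table name.
path = "s3://my-bucket/my-dataset/"
database = "my_database"
table = "my_table"

# Create the dataset with two columns.
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
wr.s3.to_csv(df=df, path=path, dataset=True, database=database, table=table, index=False)

# Append with an extra column; with schema_evolution=True the Glue table now
# gains column "c2" (the same update path already used by to_parquet).
df["c2"] = [6, 7, 8]
wr.s3.to_csv(
    df=df,
    path=path,
    dataset=True,
    database=database,
    table=table,
    mode="append",
    schema_evolution=True,
    index=False,
)

# With schema_evolution=False, adding yet another column raises
# wr.exceptions.InvalidArgumentValue, as asserted in the tests below.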

awswrangler/catalog/_create.py

Lines changed: 42 additions & 12 deletions
@@ -252,6 +252,29 @@ def _overwrite_table_parameters(
     return parameters


+def _update_table_input(table_input: Dict[str, Any], columns_types: Dict[str, str], allow_reorder: bool = True) -> bool:
+    column_updated = False
+
+    catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
+
+    if not allow_reorder:
+        for catalog_key, frame_key in zip(catalog_cols, columns_types):
+            if catalog_key != frame_key:
+                raise exceptions.InvalidArgumentValue(f"Column {frame_key} is out of order.")
+
+    for c, t in columns_types.items():
+        if c not in catalog_cols:
+            _logger.debug("New column %s with type %s.", c, t)
+            table_input["StorageDescriptor"]["Columns"].append({"Name": c, "Type": t})
+            column_updated = True
+        elif t != catalog_cols[c]:  # Data type change detected!
+            raise exceptions.InvalidArgumentValue(
+                f"Data type change detected on column {c} (Old type: {catalog_cols[c]} / New type {t})."
+            )
+
+    return column_updated
+
+
 def _create_parquet_table(
     database: str,
     table: str,
@@ -282,19 +305,14 @@ def _create_parquet_table(
     table = sanitize_table_name(table=table)
     partitions_types = {} if partitions_types is None else partitions_types
     _logger.debug("catalog_table_input: %s", catalog_table_input)
+
     table_input: Dict[str, Any]
     if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
         table_input = catalog_table_input
-        catalog_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]}
-        for c, t in columns_types.items():
-            if c not in catalog_cols:
-                _logger.debug("New column %s with type %s.", c, t)
-                table_input["StorageDescriptor"]["Columns"].append({"Name": c, "Type": t})
-                mode = "update"
-            elif t != catalog_cols[c]:  # Data type change detected!
-                raise exceptions.InvalidArgumentValue(
-                    f"Data type change detected on column {c} (Old type: {catalog_cols[c]} / New type {t})."
-                )
+
+        is_table_updated = _update_table_input(table_input, columns_types)
+        if is_table_updated:
+            mode = "update"
     else:
         table_input = _parquet_table_definition(
             table=table,
@@ -368,11 +386,18 @@ def _create_csv_table(  # pylint: disable=too-many-arguments,too-many-locals
     table = sanitize_table_name(table=table)
     partitions_types = {} if partitions_types is None else partitions_types
     _logger.debug("catalog_table_input: %s", catalog_table_input)
-    table_input: Dict[str, Any]
+
     if schema_evolution is False:
         _utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
+
+    table_input: Dict[str, Any]
     if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
         table_input = catalog_table_input
+
+        is_table_updated = _update_table_input(table_input, columns_types, allow_reorder=False)
+        if is_table_updated:
+            mode = "update"
+
     else:
         table_input = _csv_table_definition(
             table=table,
@@ -415,7 +440,7 @@ def _create_csv_table(  # pylint: disable=too-many-arguments,too-many-locals
     )


-def _create_json_table(  # pylint: disable=too-many-arguments
+def _create_json_table(  # pylint: disable=too-many-arguments,too-many-locals
     database: str,
     table: str,
     path: str,
@@ -453,6 +478,11 @@ def _create_json_table(  # pylint: disable=too-many-arguments
         _utils.check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode)
     if (catalog_table_input is not None) and (mode in ("append", "overwrite_partitions")):
         table_input = catalog_table_input
+
+        is_table_updated = _update_table_input(table_input, columns_types)
+        if is_table_updated:
+            mode = "update"
+
     else:
         table_input = _json_table_definition(
             table=table,
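
For orientation, the new _update_table_input helper operates on the Glue table input dict: it appends any frame column missing from StorageDescriptor.Columns, returns True when it changed something (so the caller flips mode to "update"), and raises on a data type change. A minimal sketch of that logic on a hand-written table_input dict (illustrative shape only, not an actual Glue response):

from typing import Any, Dict

# Hand-written stand-in for a Glue table input (shape only).
table_input: Dict[str, Any] = {
    "StorageDescriptor": {
        "Columns": [
            {"Name": "c0", "Type": "bigint"},
            {"Name": "c1", "Type": "bigint"},
        ]
    }
}
# Schema of the incoming frame, with one new column.
columns_types: Dict[str, str] = {"c0": "bigint", "c1": "bigint", "c2": "string"}

# Same logic as _update_table_input: append unknown columns, reject type changes.
catalog_cols = {col["Name"]: col["Type"] for col in table_input["StorageDescriptor"]["Columns"]}
updated = False
for name, dtype in columns_types.items():
    if name not in catalog_cols:
        table_input["StorageDescriptor"]["Columns"].append({"Name": name, "Type": dtype})
        updated = True
    elif dtype != catalog_cols[name]:
        # awswrangler raises exceptions.InvalidArgumentValue here; ValueError stands in.
        raise ValueError(f"Data type change detected on column {name}")

assert updated  # "c2" was appended, so the caller would switch mode to "update"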

tests/test_s3_parquet.py

Lines changed: 60 additions & 1 deletion
@@ -1,7 +1,7 @@
 import itertools
 import logging
 import math
-from datetime import datetime, timedelta, timezone
+from datetime import date, datetime, timedelta, timezone

 import boto3
 import numpy as np
@@ -571,6 +571,65 @@ def test_read_parquet_versioned(path) -> None:
     assert version_id == wr.s3.describe_objects(path=path_file, version_id=version_id)[path_file]["VersionId"]


+def test_parquet_schema_evolution(path, glue_database, glue_table):
+    df = pd.DataFrame(
+        {
+            "id": [1, 2],
+            "value": ["foo", "boo"],
+        }
+    )
+    wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+    )
+
+    df2 = pd.DataFrame(
+        {"id": [3, 4], "value": ["bar", None], "date": [date(2020, 1, 3), date(2020, 1, 4)], "flag": [True, False]}
+    )
+    wr.s3.to_parquet(
+        df=df2,
+        path=path,
+        dataset=True,
+        mode="append",
+        database=glue_database,
+        table=glue_table,
+        schema_evolution=True,
+        catalog_versioning=True,
+    )
+
+    column_types = wr.catalog.get_table_types(glue_database, glue_table)
+    assert len(column_types) == len(df2.columns)
+
+
+def test_to_parquet_schema_evolution_out_of_order(path, glue_database, glue_table) -> None:
+    df = pd.DataFrame({"c0": [0, 1, 2], "c1": ["a", "b", "c"]})
+    wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
+
+    df2 = df.copy()
+    df2["c2"] = ["x", "y", "z"]
+
+    wr.s3.to_parquet(
+        df=df2,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        mode="append",
+        schema_evolution=True,
+        catalog_versioning=True,
+    )
+
+    df_out = wr.s3.read_parquet(path=path, dataset=True)
+    df_expected = pd.concat([df, df2], ignore_index=True)
+
+    assert len(df_out) == len(df_expected)
+    assert list(df_out.columns) == list(df_expected.columns)
+
+
 def test_read_parquet_schema_validation_with_index_column(path) -> None:
     path_file = f"{path}file.parquet"
     df = pd.DataFrame({"idx": [1], "col": [2]})

tests/test_s3_text.py

Lines changed: 65 additions & 6 deletions
@@ -359,27 +359,86 @@ def test_read_csv_versioned(path) -> None:
     assert version_id == wr.s3.describe_objects(path=path_file, version_id=version_id)[path_file]["VersionId"]


-def test_to_csv_schema_evolution(path, glue_database, glue_table) -> None:
-    path_file = f"{path}0.csv"
+@pytest.mark.parametrize("mode", ["append", "overwrite"])
+def test_to_csv_schema_evolution(path, glue_database, glue_table, mode) -> None:
     df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
-    wr.s3.to_csv(df=df, path=path_file, dataset=True, database=glue_database, table=glue_table)
+    wr.s3.to_csv(df=df, path=path, dataset=True, database=glue_database, table=glue_table, index=False)
+
     df["c2"] = [6, 7, 8]
     wr.s3.to_csv(
         df=df,
-        path=path_file,
+        path=path,
         dataset=True,
         database=glue_database,
         table=glue_table,
-        mode="overwrite",
+        mode=mode,
         schema_evolution=True,
+        index=False,
     )
+
+    column_types = wr.catalog.get_table_types(glue_database, glue_table)
+    assert len(column_types) == len(df.columns)
+
     df["c3"] = [9, 10, 11]
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_csv(df=df, path=path, dataset=True, database=glue_database, table=glue_table, schema_evolution=False)
+
+
+@pytest.mark.parametrize("schema_evolution", [False, True])
+def test_to_csv_schema_evolution_out_of_order(path, glue_database, glue_table, schema_evolution) -> None:
+    df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
+    wr.s3.to_csv(df=df, path=path, dataset=True, database=glue_database, table=glue_table, index=False)
+
+    df["c2"] = [6, 7, 8]
+    df = df[["c0", "c2", "c1"]]
+
     with pytest.raises(wr.exceptions.InvalidArgumentValue):
         wr.s3.to_csv(
-            df=df, path=path_file, dataset=True, database=glue_database, table=glue_table, schema_evolution=False
+            df=df,
+            path=path,
+            dataset=True,
+            database=glue_database,
+            table=glue_table,
+            mode="append",
+            schema_evolution=schema_evolution,
+            index=False,
         )


+@pytest.mark.parametrize("mode", ["append", "overwrite"])
+def test_to_json_schema_evolution(path, glue_database, glue_table, mode) -> None:
+    df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5]})
+    wr.s3.to_json(
+        df=df,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        orient="split",
+        index=False,
+    )
+
+    df["c2"] = [6, 7, 8]
+    wr.s3.to_json(
+        df=df,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        mode=mode,
+        schema_evolution=True,
+        orient="split",
+        index=False,
+    )
+
+    column_types = wr.catalog.get_table_types(glue_database, glue_table)
+    assert len(column_types) == len(df.columns)
+
+    df["c3"] = [9, 10, 11]
+    with pytest.raises(wr.exceptions.InvalidArgumentValue):
+        wr.s3.to_json(df=df, path=path, dataset=True, database=glue_database, table=glue_table, schema_evolution=False)
+
+
 def test_exceptions(path):
     with pytest.raises(wr.exceptions.EmptyDataFrame):
         wr.s3.to_json(df=pd.DataFrame(), path=path)
