Skip to content

Fix extension dtype index handling #1333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions awswrangler/_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def pyarrow2pandas_extension( # pylint: disable=too-many-branches,too-many-retu
return None


def pyarrow_types_from_pandas(
def pyarrow_types_from_pandas( # pylint: disable=too-many-branches
df: pd.DataFrame, index: bool, ignore_cols: Optional[List[str]] = None, index_left: bool = False
) -> Dict[str, pa.DataType]:
"""Extract the related Pyarrow data types from any Pandas DataFrame."""
Expand Down Expand Up @@ -469,7 +469,19 @@ def pyarrow_types_from_pandas(
# Filling indexes
indexes: List[str] = []
if index is True:
for field in pa.Schema.from_pandas(df=df[[]], preserve_index=True):
# Get index columns
try:
fields = pa.Schema.from_pandas(df=df[[]], preserve_index=True)
except AttributeError as ae:
if "'Index' object has no attribute 'head'" not in str(ae):
raise ae
# Fallback: infer the index fields from a new df that holds only the index columns.
# The index is converted into regular columns via .reset_index() because
# pa.Schema.from_pandas(..., preserve_index=True) fails with
# "'Index' object has no attribute 'head'" when the index uses extension
# dtypes on pandas 1.4.x
fields = pa.Schema.from_pandas(df=df.reset_index().drop(columns=cols), preserve_index=False)
for field in fields:
name = str(field.name)
_logger.debug("Inferring PyArrow type from index: %s", name)
cols_dtypes[name] = field.type
Expand Down