diff --git a/pandas/io/formats/excel.py b/pandas/io/formats/excel.py index fe471c6f6f9ac..25885552239d6 100644 --- a/pandas/io/formats/excel.py +++ b/pandas/io/formats/excel.py @@ -5,16 +5,16 @@ from functools import reduce import itertools import re -from typing import Callable, Dict, Mapping, Optional, Sequence, Union +from typing import Callable, Dict, Iterable, Mapping, Optional, Sequence, Union, cast import warnings import numpy as np +from pandas._libs.lib import is_list_like from pandas._typing import Label, StorageOptions from pandas.core.dtypes import missing from pandas.core.dtypes.common import is_float, is_scalar -from pandas.core.dtypes.generic import ABCIndex from pandas import DataFrame, Index, MultiIndex, PeriodIndex import pandas.core.common as com @@ -29,7 +29,13 @@ class ExcelCell: __slots__ = __fields__ def __init__( - self, row: int, col: int, val, style=None, mergestart=None, mergeend=None + self, + row: int, + col: int, + val, + style=None, + mergestart: Optional[int] = None, + mergeend: Optional[int] = None, ): self.row = row self.col = col @@ -423,7 +429,7 @@ class ExcelFormatter: Format string for floating point numbers cols : sequence, optional Columns to write - header : boolean or list of string, default True + header : boolean or sequence of str, default True Write out column names. If a list of string is given it is assumed to be aliases for the column names index : boolean, default True @@ -522,7 +528,7 @@ def _format_value(self, val): ) return val - def _format_header_mi(self): + def _format_header_mi(self) -> Iterable[ExcelCell]: if self.columns.nlevels > 1: if not self.index: raise NotImplementedError( @@ -530,8 +536,7 @@ def _format_header_mi(self): "index ('index'=False) is not yet implemented." ) - has_aliases = isinstance(self.header, (tuple, list, np.ndarray, ABCIndex)) - if not (has_aliases or self.header): + if not (self._has_aliases or self.header): return columns = self.columns @@ -547,28 +552,30 @@ def _format_header_mi(self): if self.merge_cells: # Format multi-index as a merged cells. - for lnum in range(len(level_lengths)): - name = columns.names[lnum] - yield ExcelCell(lnum, coloffset, name, self.header_style) + for lnum, name in enumerate(columns.names): + yield ExcelCell( + row=lnum, + col=coloffset, + val=name, + style=self.header_style, + ) for lnum, (spans, levels, level_codes) in enumerate( zip(level_lengths, columns.levels, columns.codes) ): values = levels.take(level_codes) - for i in spans: - if spans[i] > 1: - yield ExcelCell( - lnum, - coloffset + i + 1, - values[i], - self.header_style, - lnum, - coloffset + i + spans[i], - ) - else: - yield ExcelCell( - lnum, coloffset + i + 1, values[i], self.header_style - ) + for i, span_val in spans.items(): + spans_multiple_cells = span_val > 1 + yield ExcelCell( + row=lnum, + col=coloffset + i + 1, + val=values[i], + style=self.header_style, + mergestart=lnum if spans_multiple_cells else None, + mergeend=( + coloffset + i + span_val if spans_multiple_cells else None + ), + ) else: # Format in legacy format with dots to indicate levels. for i, values in enumerate(zip(*level_strs)): @@ -577,9 +584,8 @@ def _format_header_mi(self): self.rowcounter = lnum - def _format_header_regular(self): - has_aliases = isinstance(self.header, (tuple, list, np.ndarray, ABCIndex)) - if has_aliases or self.header: + def _format_header_regular(self) -> Iterable[ExcelCell]: + if self._has_aliases or self.header: coloffset = 0 if self.index: @@ -588,17 +594,11 @@ def _format_header_regular(self): coloffset = len(self.df.index[0]) colnames = self.columns - if has_aliases: - # pandas\io\formats\excel.py:593: error: Argument 1 to "len" - # has incompatible type "Union[Sequence[Optional[Hashable]], - # bool]"; expected "Sized" [arg-type] - if len(self.header) != len(self.columns): # type: ignore[arg-type] - # pandas\io\formats\excel.py:602: error: Argument 1 to - # "len" has incompatible type - # "Union[Sequence[Optional[Hashable]], bool]"; expected - # "Sized" [arg-type] + if self._has_aliases: + self.header = cast(Sequence, self.header) + if len(self.header) != len(self.columns): raise ValueError( - f"Writing {len(self.columns)} cols " # type: ignore[arg-type] + f"Writing {len(self.columns)} cols " f"but got {len(self.header)} aliases" ) else: @@ -609,7 +609,7 @@ def _format_header_regular(self): self.rowcounter, colindex + coloffset, colname, self.header_style ) - def _format_header(self): + def _format_header(self) -> Iterable[ExcelCell]: if isinstance(self.columns, MultiIndex): gen = self._format_header_mi() else: @@ -631,15 +631,14 @@ def _format_header(self): self.rowcounter += 1 return itertools.chain(gen, gen2) - def _format_body(self): + def _format_body(self) -> Iterable[ExcelCell]: if isinstance(self.df.index, MultiIndex): return self._format_hierarchical_rows() else: return self._format_regular_rows() - def _format_regular_rows(self): - has_aliases = isinstance(self.header, (tuple, list, np.ndarray, ABCIndex)) - if has_aliases or self.header: + def _format_regular_rows(self) -> Iterable[ExcelCell]: + if self._has_aliases or self.header: self.rowcounter += 1 # output index and index_label? @@ -676,9 +675,8 @@ def _format_regular_rows(self): yield from self._generate_body(coloffset) - def _format_hierarchical_rows(self): - has_aliases = isinstance(self.header, (tuple, list, np.ndarray, ABCIndex)) - if has_aliases or self.header: + def _format_hierarchical_rows(self) -> Iterable[ExcelCell]: + if self._has_aliases or self.header: self.rowcounter += 1 gcolidx = 0 @@ -721,23 +719,20 @@ def _format_hierarchical_rows(self): fill_value=levels._na_value, ) - for i in spans: - if spans[i] > 1: - yield ExcelCell( - self.rowcounter + i, - gcolidx, - values[i], - self.header_style, - self.rowcounter + i + spans[i] - 1, - gcolidx, - ) - else: - yield ExcelCell( - self.rowcounter + i, - gcolidx, - values[i], - self.header_style, - ) + for i, span_val in spans.items(): + spans_multiple_cells = span_val > 1 + yield ExcelCell( + row=self.rowcounter + i, + col=gcolidx, + val=values[i], + style=self.header_style, + mergestart=( + self.rowcounter + i + span_val - 1 + if spans_multiple_cells + else None + ), + mergeend=gcolidx if spans_multiple_cells else None, + ) gcolidx += 1 else: @@ -745,16 +740,21 @@ def _format_hierarchical_rows(self): for indexcolvals in zip(*self.df.index): for idx, indexcolval in enumerate(indexcolvals): yield ExcelCell( - self.rowcounter + idx, - gcolidx, - indexcolval, - self.header_style, + row=self.rowcounter + idx, + col=gcolidx, + val=indexcolval, + style=self.header_style, ) gcolidx += 1 yield from self._generate_body(gcolidx) - def _generate_body(self, coloffset: int): + @property + def _has_aliases(self) -> bool: + """Whether the aliases for column names are present.""" + return is_list_like(self.header) + + def _generate_body(self, coloffset: int) -> Iterable[ExcelCell]: if self.styler is None: styles = None else: @@ -771,7 +771,7 @@ def _generate_body(self, coloffset: int): xlstyle = self.style_converter(";".join(styles[i, colidx])) yield ExcelCell(self.rowcounter + i, colidx + coloffset, val, xlstyle) - def get_formatted_cells(self): + def get_formatted_cells(self) -> Iterable[ExcelCell]: for cell in itertools.chain(self._format_header(), self._format_body()): cell.val = self._format_value(cell.val) yield cell