Skip to content

Commit 88ccb25

Browse files
simonjayhawkinsWillAyd
authored andcommitted
add some type annotations io/formats/format.py (#27418)
1 parent 8f4295a commit 88ccb25

File tree

2 files changed

+65
-47
lines changed

2 files changed

+65
-47
lines changed

pandas/io/formats/format.py

+56-38
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from functools import partial
77
from io import StringIO
88
from shutil import get_terminal_size
9+
from typing import TYPE_CHECKING, List, Optional, TextIO, Tuple, Union, cast
910
from unicodedata import east_asian_width
1011

1112
import numpy as np
@@ -47,6 +48,9 @@
4748
from pandas.io.common import _expand_user, _stringify_path
4849
from pandas.io.formats.printing import adjoin, justify, pprint_thing
4950

51+
if TYPE_CHECKING:
52+
from pandas import Series, DataFrame, Categorical
53+
5054
common_docstring = """
5155
Parameters
5256
----------
@@ -127,14 +131,21 @@
127131

128132

129133
class CategoricalFormatter:
130-
def __init__(self, categorical, buf=None, length=True, na_rep="NaN", footer=True):
134+
def __init__(
135+
self,
136+
categorical: "Categorical",
137+
buf: Optional[TextIO] = None,
138+
length: bool = True,
139+
na_rep: str = "NaN",
140+
footer: bool = True,
141+
):
131142
self.categorical = categorical
132143
self.buf = buf if buf is not None else StringIO("")
133144
self.na_rep = na_rep
134145
self.length = length
135146
self.footer = footer
136147

137-
def _get_footer(self):
148+
def _get_footer(self) -> str:
138149
footer = ""
139150

140151
if self.length:
@@ -151,15 +162,15 @@ def _get_footer(self):
151162

152163
return str(footer)
153164

154-
def _get_formatted_values(self):
165+
def _get_formatted_values(self) -> List[str]:
155166
return format_array(
156167
self.categorical._internal_get_values(),
157168
None,
158169
float_format=None,
159170
na_rep=self.na_rep,
160171
)
161172

162-
def to_string(self):
173+
def to_string(self) -> str:
163174
categorical = self.categorical
164175

165176
if len(categorical) == 0:
@@ -170,10 +181,10 @@ def to_string(self):
170181

171182
fmt_values = self._get_formatted_values()
172183

173-
result = ["{i}".format(i=i) for i in fmt_values]
174-
result = [i.strip() for i in result]
175-
result = ", ".join(result)
176-
result = ["[" + result + "]"]
184+
fmt_values = ["{i}".format(i=i) for i in fmt_values]
185+
fmt_values = [i.strip() for i in fmt_values]
186+
values = ", ".join(fmt_values)
187+
result = ["[" + values + "]"]
177188
if self.footer:
178189
footer = self._get_footer()
179190
if footer:
@@ -185,17 +196,17 @@ def to_string(self):
185196
class SeriesFormatter:
186197
def __init__(
187198
self,
188-
series,
189-
buf=None,
190-
length=True,
191-
header=True,
192-
index=True,
193-
na_rep="NaN",
194-
name=False,
195-
float_format=None,
196-
dtype=True,
197-
max_rows=None,
198-
min_rows=None,
199+
series: "Series",
200+
buf: Optional[TextIO] = None,
201+
length: bool = True,
202+
header: bool = True,
203+
index: bool = True,
204+
na_rep: str = "NaN",
205+
name: bool = False,
206+
float_format: Optional[str] = None,
207+
dtype: bool = True,
208+
max_rows: Optional[int] = None,
209+
min_rows: Optional[int] = None,
199210
):
200211
self.series = series
201212
self.buf = buf if buf is not None else StringIO()
@@ -215,7 +226,7 @@ def __init__(
215226

216227
self._chk_truncate()
217228

218-
def _chk_truncate(self):
229+
def _chk_truncate(self) -> None:
219230
from pandas.core.reshape.concat import concat
220231

221232
min_rows = self.min_rows
@@ -225,6 +236,7 @@ def _chk_truncate(self):
225236
truncate_v = max_rows and (len(self.series) > max_rows)
226237
series = self.series
227238
if truncate_v:
239+
max_rows = cast(int, max_rows)
228240
if min_rows:
229241
# if min_rows is set (not None or 0), set max_rows to minimum
230242
# of both
@@ -235,13 +247,13 @@ def _chk_truncate(self):
235247
else:
236248
row_num = max_rows // 2
237249
series = concat((series.iloc[:row_num], series.iloc[-row_num:]))
238-
self.tr_row_num = row_num
250+
self.tr_row_num = row_num # type: Optional[int]
239251
else:
240252
self.tr_row_num = None
241253
self.tr_series = series
242254
self.truncate_v = truncate_v
243255

244-
def _get_footer(self):
256+
def _get_footer(self) -> str:
245257
name = self.series.name
246258
footer = ""
247259

@@ -279,7 +291,7 @@ def _get_footer(self):
279291

280292
return str(footer)
281293

282-
def _get_formatted_index(self):
294+
def _get_formatted_index(self) -> Tuple[List[str], bool]:
283295
index = self.tr_series.index
284296
is_multi = isinstance(index, ABCMultiIndex)
285297

@@ -291,13 +303,13 @@ def _get_formatted_index(self):
291303
fmt_index = index.format(name=True)
292304
return fmt_index, have_header
293305

294-
def _get_formatted_values(self):
306+
def _get_formatted_values(self) -> List[str]:
295307
values_to_format = self.tr_series._formatting_values()
296308
return format_array(
297309
values_to_format, None, float_format=self.float_format, na_rep=self.na_rep
298310
)
299311

300-
def to_string(self):
312+
def to_string(self) -> str:
301313
series = self.tr_series
302314
footer = self._get_footer()
303315

@@ -312,6 +324,7 @@ def to_string(self):
312324
if self.truncate_v:
313325
n_header_rows = 0
314326
row_num = self.tr_row_num
327+
row_num = cast(int, row_num)
315328
width = self.adj.len(fmt_values[row_num - 1])
316329
if width > 3:
317330
dot_str = "..."
@@ -499,7 +512,7 @@ def __init__(
499512
self._chk_truncate()
500513
self.adj = _get_adjustment()
501514

502-
def _chk_truncate(self):
515+
def _chk_truncate(self) -> None:
503516
"""
504517
Checks whether the frame should be truncated. If so, slices
505518
the frame up.
@@ -575,7 +588,7 @@ def _chk_truncate(self):
575588
self.truncate_v = truncate_v
576589
self.is_truncated = self.truncate_h or self.truncate_v
577590

578-
def _to_str_columns(self):
591+
def _to_str_columns(self) -> List[List[str]]:
579592
"""
580593
Render a DataFrame to a list of columns (as lists of strings).
581594
"""
@@ -665,7 +678,7 @@ def _to_str_columns(self):
665678
strcols[ix].insert(row_num + n_header_rows, dot_str)
666679
return strcols
667680

668-
def to_string(self):
681+
def to_string(self) -> None:
669682
"""
670683
Render a DataFrame to a console-friendly tabular output.
671684
"""
@@ -801,7 +814,7 @@ def to_latex(
801814
else:
802815
raise TypeError("buf is not a file name and it has no write " "method")
803816

804-
def _format_col(self, i):
817+
def _format_col(self, i: int) -> List[str]:
805818
frame = self.tr_frame
806819
formatter = self._get_formatter(i)
807820
values_to_format = frame.iloc[:, i]._formatting_values()
@@ -814,7 +827,12 @@ def _format_col(self, i):
814827
decimal=self.decimal,
815828
)
816829

817-
def to_html(self, classes=None, notebook=False, border=None):
830+
def to_html(
831+
self,
832+
classes: Optional[Union[str, List, Tuple]] = None,
833+
notebook: bool = False,
834+
border: Optional[int] = None,
835+
) -> None:
818836
"""
819837
Render a DataFrame to a html table.
820838
@@ -843,7 +861,7 @@ def to_html(self, classes=None, notebook=False, border=None):
843861
else:
844862
raise TypeError("buf is not a file name and it has no write " " method")
845863

846-
def _get_formatted_column_labels(self, frame):
864+
def _get_formatted_column_labels(self, frame: "DataFrame") -> List[List[str]]:
847865
from pandas.core.index import _sparsify
848866

849867
columns = frame.columns
@@ -885,22 +903,22 @@ def space_format(x, y):
885903
return str_columns
886904

887905
@property
888-
def has_index_names(self):
906+
def has_index_names(self) -> bool:
889907
return _has_names(self.frame.index)
890908

891909
@property
892-
def has_column_names(self):
910+
def has_column_names(self) -> bool:
893911
return _has_names(self.frame.columns)
894912

895913
@property
896-
def show_row_idx_names(self):
914+
def show_row_idx_names(self) -> bool:
897915
return all((self.has_index_names, self.index, self.show_index_names))
898916

899917
@property
900-
def show_col_idx_names(self):
918+
def show_col_idx_names(self) -> bool:
901919
return all((self.has_column_names, self.show_index_names, self.header))
902920

903-
def _get_formatted_index(self, frame):
921+
def _get_formatted_index(self, frame: "DataFrame") -> List[str]:
904922
# Note: this is only used by to_string() and to_latex(), not by
905923
# to_html().
906924
index = frame.index
@@ -939,8 +957,8 @@ def _get_formatted_index(self, frame):
939957
else:
940958
return adjoined
941959

942-
def _get_column_name_list(self):
943-
names = []
960+
def _get_column_name_list(self) -> List[str]:
961+
names = [] # type: List[str]
944962
columns = self.frame.columns
945963
if isinstance(columns, ABCMultiIndex):
946964
names.extend("" if name is None else name for name in columns.names)

pandas/io/formats/html.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
from collections import OrderedDict
66
from textwrap import dedent
7-
from typing import Dict, List, Optional, Tuple, Union
7+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
88

99
from pandas._config import get_option
1010

11-
from pandas.core.dtypes.generic import ABCIndex, ABCMultiIndex
11+
from pandas.core.dtypes.generic import ABCMultiIndex
1212

1313
from pandas import option_context
1414

@@ -37,7 +37,7 @@ def __init__(
3737
self,
3838
formatter: DataFrameFormatter,
3939
classes: Optional[Union[str, List, Tuple]] = None,
40-
border: Optional[bool] = None,
40+
border: Optional[int] = None,
4141
) -> None:
4242
self.fmt = formatter
4343
self.classes = classes
@@ -79,7 +79,7 @@ def row_levels(self) -> int:
7979
# not showing (row) index
8080
return 0
8181

82-
def _get_columns_formatted_values(self) -> ABCIndex:
82+
def _get_columns_formatted_values(self) -> Iterable:
8383
return self.columns
8484

8585
@property
@@ -90,12 +90,12 @@ def is_truncated(self) -> bool:
9090
def ncols(self) -> int:
9191
return len(self.fmt.tr_frame.columns)
9292

93-
def write(self, s: str, indent: int = 0) -> None:
93+
def write(self, s: Any, indent: int = 0) -> None:
9494
rs = pprint_thing(s)
9595
self.elements.append(" " * indent + rs)
9696

9797
def write_th(
98-
self, s: str, header: bool = False, indent: int = 0, tags: Optional[str] = None
98+
self, s: Any, header: bool = False, indent: int = 0, tags: Optional[str] = None
9999
) -> None:
100100
"""
101101
Method for writting a formatted <th> cell.
@@ -125,11 +125,11 @@ def write_th(
125125

126126
self._write_cell(s, kind="th", indent=indent, tags=tags)
127127

128-
def write_td(self, s: str, indent: int = 0, tags: Optional[str] = None) -> None:
128+
def write_td(self, s: Any, indent: int = 0, tags: Optional[str] = None) -> None:
129129
self._write_cell(s, kind="td", indent=indent, tags=tags)
130130

131131
def _write_cell(
132-
self, s: str, kind: str = "td", indent: int = 0, tags: Optional[str] = None
132+
self, s: Any, kind: str = "td", indent: int = 0, tags: Optional[str] = None
133133
) -> None:
134134
if tags is not None:
135135
start_tag = "<{kind} {tags}>".format(kind=kind, tags=tags)
@@ -162,7 +162,7 @@ def _write_cell(
162162

163163
def write_tr(
164164
self,
165-
line: List[str],
165+
line: Iterable,
166166
indent: int = 0,
167167
indent_delta: int = 0,
168168
header: bool = False,

0 commit comments

Comments
 (0)