6
6
from functools import partial
7
7
from io import StringIO
8
8
from shutil import get_terminal_size
9
+ from typing import TYPE_CHECKING , List , Optional , TextIO , Tuple , Union , cast
9
10
from unicodedata import east_asian_width
10
11
11
12
import numpy as np
47
48
from pandas .io .common import _expand_user , _stringify_path
48
49
from pandas .io .formats .printing import adjoin , justify , pprint_thing
49
50
51
+ if TYPE_CHECKING :
52
+ from pandas import Series , DataFrame , Categorical
53
+
50
54
common_docstring = """
51
55
Parameters
52
56
----------
127
131
128
132
129
133
class CategoricalFormatter :
130
- def __init__ (self , categorical , buf = None , length = True , na_rep = "NaN" , footer = True ):
134
+ def __init__ (
135
+ self ,
136
+ categorical : "Categorical" ,
137
+ buf : Optional [TextIO ] = None ,
138
+ length : bool = True ,
139
+ na_rep : str = "NaN" ,
140
+ footer : bool = True ,
141
+ ):
131
142
self .categorical = categorical
132
143
self .buf = buf if buf is not None else StringIO ("" )
133
144
self .na_rep = na_rep
134
145
self .length = length
135
146
self .footer = footer
136
147
137
- def _get_footer (self ):
148
+ def _get_footer (self ) -> str :
138
149
footer = ""
139
150
140
151
if self .length :
@@ -151,15 +162,15 @@ def _get_footer(self):
151
162
152
163
return str (footer )
153
164
154
- def _get_formatted_values (self ):
165
+ def _get_formatted_values (self ) -> List [ str ] :
155
166
return format_array (
156
167
self .categorical ._internal_get_values (),
157
168
None ,
158
169
float_format = None ,
159
170
na_rep = self .na_rep ,
160
171
)
161
172
162
- def to_string (self ):
173
+ def to_string (self ) -> str :
163
174
categorical = self .categorical
164
175
165
176
if len (categorical ) == 0 :
@@ -170,10 +181,10 @@ def to_string(self):
170
181
171
182
fmt_values = self ._get_formatted_values ()
172
183
173
- result = ["{i}" .format (i = i ) for i in fmt_values ]
174
- result = [i .strip () for i in result ]
175
- result = ", " .join (result )
176
- result = ["[" + result + "]" ]
184
+ fmt_values = ["{i}" .format (i = i ) for i in fmt_values ]
185
+ fmt_values = [i .strip () for i in fmt_values ]
186
+ values = ", " .join (fmt_values )
187
+ result = ["[" + values + "]" ]
177
188
if self .footer :
178
189
footer = self ._get_footer ()
179
190
if footer :
@@ -185,17 +196,17 @@ def to_string(self):
185
196
class SeriesFormatter :
186
197
def __init__ (
187
198
self ,
188
- series ,
189
- buf = None ,
190
- length = True ,
191
- header = True ,
192
- index = True ,
193
- na_rep = "NaN" ,
194
- name = False ,
195
- float_format = None ,
196
- dtype = True ,
197
- max_rows = None ,
198
- min_rows = None ,
199
+ series : "Series" ,
200
+ buf : Optional [ TextIO ] = None ,
201
+ length : bool = True ,
202
+ header : bool = True ,
203
+ index : bool = True ,
204
+ na_rep : str = "NaN" ,
205
+ name : bool = False ,
206
+ float_format : Optional [ str ] = None ,
207
+ dtype : bool = True ,
208
+ max_rows : Optional [ int ] = None ,
209
+ min_rows : Optional [ int ] = None ,
199
210
):
200
211
self .series = series
201
212
self .buf = buf if buf is not None else StringIO ()
@@ -215,7 +226,7 @@ def __init__(
215
226
216
227
self ._chk_truncate ()
217
228
218
- def _chk_truncate (self ):
229
+ def _chk_truncate (self ) -> None :
219
230
from pandas .core .reshape .concat import concat
220
231
221
232
min_rows = self .min_rows
@@ -225,6 +236,7 @@ def _chk_truncate(self):
225
236
truncate_v = max_rows and (len (self .series ) > max_rows )
226
237
series = self .series
227
238
if truncate_v :
239
+ max_rows = cast (int , max_rows )
228
240
if min_rows :
229
241
# if min_rows is set (not None or 0), set max_rows to minimum
230
242
# of both
@@ -235,13 +247,13 @@ def _chk_truncate(self):
235
247
else :
236
248
row_num = max_rows // 2
237
249
series = concat ((series .iloc [:row_num ], series .iloc [- row_num :]))
238
- self .tr_row_num = row_num
250
+ self .tr_row_num = row_num # type: Optional[int]
239
251
else :
240
252
self .tr_row_num = None
241
253
self .tr_series = series
242
254
self .truncate_v = truncate_v
243
255
244
- def _get_footer (self ):
256
+ def _get_footer (self ) -> str :
245
257
name = self .series .name
246
258
footer = ""
247
259
@@ -279,7 +291,7 @@ def _get_footer(self):
279
291
280
292
return str (footer )
281
293
282
- def _get_formatted_index (self ):
294
+ def _get_formatted_index (self ) -> Tuple [ List [ str ], bool ] :
283
295
index = self .tr_series .index
284
296
is_multi = isinstance (index , ABCMultiIndex )
285
297
@@ -291,13 +303,13 @@ def _get_formatted_index(self):
291
303
fmt_index = index .format (name = True )
292
304
return fmt_index , have_header
293
305
294
- def _get_formatted_values (self ):
306
+ def _get_formatted_values (self ) -> List [ str ] :
295
307
values_to_format = self .tr_series ._formatting_values ()
296
308
return format_array (
297
309
values_to_format , None , float_format = self .float_format , na_rep = self .na_rep
298
310
)
299
311
300
- def to_string (self ):
312
+ def to_string (self ) -> str :
301
313
series = self .tr_series
302
314
footer = self ._get_footer ()
303
315
@@ -312,6 +324,7 @@ def to_string(self):
312
324
if self .truncate_v :
313
325
n_header_rows = 0
314
326
row_num = self .tr_row_num
327
+ row_num = cast (int , row_num )
315
328
width = self .adj .len (fmt_values [row_num - 1 ])
316
329
if width > 3 :
317
330
dot_str = "..."
@@ -499,7 +512,7 @@ def __init__(
499
512
self ._chk_truncate ()
500
513
self .adj = _get_adjustment ()
501
514
502
- def _chk_truncate (self ):
515
+ def _chk_truncate (self ) -> None :
503
516
"""
504
517
Checks whether the frame should be truncated. If so, slices
505
518
the frame up.
@@ -575,7 +588,7 @@ def _chk_truncate(self):
575
588
self .truncate_v = truncate_v
576
589
self .is_truncated = self .truncate_h or self .truncate_v
577
590
578
- def _to_str_columns (self ):
591
+ def _to_str_columns (self ) -> List [ List [ str ]] :
579
592
"""
580
593
Render a DataFrame to a list of columns (as lists of strings).
581
594
"""
@@ -665,7 +678,7 @@ def _to_str_columns(self):
665
678
strcols [ix ].insert (row_num + n_header_rows , dot_str )
666
679
return strcols
667
680
668
- def to_string (self ):
681
+ def to_string (self ) -> None :
669
682
"""
670
683
Render a DataFrame to a console-friendly tabular output.
671
684
"""
@@ -801,7 +814,7 @@ def to_latex(
801
814
else :
802
815
raise TypeError ("buf is not a file name and it has no write " "method" )
803
816
804
- def _format_col (self , i ) :
817
+ def _format_col (self , i : int ) -> List [ str ] :
805
818
frame = self .tr_frame
806
819
formatter = self ._get_formatter (i )
807
820
values_to_format = frame .iloc [:, i ]._formatting_values ()
@@ -814,7 +827,12 @@ def _format_col(self, i):
814
827
decimal = self .decimal ,
815
828
)
816
829
817
- def to_html (self , classes = None , notebook = False , border = None ):
830
+ def to_html (
831
+ self ,
832
+ classes : Optional [Union [str , List , Tuple ]] = None ,
833
+ notebook : bool = False ,
834
+ border : Optional [int ] = None ,
835
+ ) -> None :
818
836
"""
819
837
Render a DataFrame to a html table.
820
838
@@ -843,7 +861,7 @@ def to_html(self, classes=None, notebook=False, border=None):
843
861
else :
844
862
raise TypeError ("buf is not a file name and it has no write " " method" )
845
863
846
- def _get_formatted_column_labels (self , frame ) :
864
+ def _get_formatted_column_labels (self , frame : "DataFrame" ) -> List [ List [ str ]] :
847
865
from pandas .core .index import _sparsify
848
866
849
867
columns = frame .columns
@@ -885,22 +903,22 @@ def space_format(x, y):
885
903
return str_columns
886
904
887
905
@property
888
- def has_index_names (self ):
906
+ def has_index_names (self ) -> bool :
889
907
return _has_names (self .frame .index )
890
908
891
909
@property
892
- def has_column_names (self ):
910
+ def has_column_names (self ) -> bool :
893
911
return _has_names (self .frame .columns )
894
912
895
913
@property
896
- def show_row_idx_names (self ):
914
+ def show_row_idx_names (self ) -> bool :
897
915
return all ((self .has_index_names , self .index , self .show_index_names ))
898
916
899
917
@property
900
- def show_col_idx_names (self ):
918
+ def show_col_idx_names (self ) -> bool :
901
919
return all ((self .has_column_names , self .show_index_names , self .header ))
902
920
903
- def _get_formatted_index (self , frame ) :
921
+ def _get_formatted_index (self , frame : "DataFrame" ) -> List [ str ] :
904
922
# Note: this is only used by to_string() and to_latex(), not by
905
923
# to_html().
906
924
index = frame .index
@@ -939,8 +957,8 @@ def _get_formatted_index(self, frame):
939
957
else :
940
958
return adjoined
941
959
942
- def _get_column_name_list (self ):
943
- names = []
960
+ def _get_column_name_list (self ) -> List [ str ] :
961
+ names = [] # type: List[str]
944
962
columns = self .frame .columns
945
963
if isinstance (columns , ABCMultiIndex ):
946
964
names .extend ("" if name is None else name for name in columns .names )
0 commit comments