5
5
import csv as csvlib
6
6
from io import StringIO , TextIOWrapper
7
7
import os
8
- from typing import Any , Dict , Hashable , Iterator , List , Optional , Sequence , Union
8
+ from typing import TYPE_CHECKING , Any , Dict , Iterator , List , Optional , Sequence , Union
9
9
10
10
import numpy as np
11
11
12
12
from pandas ._libs import writers as libwriters
13
13
from pandas ._typing import (
14
14
CompressionOptions ,
15
15
FilePathOrBuffer ,
16
+ FloatFormatType ,
16
17
IndexLabel ,
17
18
Label ,
18
19
StorageOptions ,
30
31
31
32
from pandas .io .common import get_filepath_or_buffer , get_handle
32
33
34
+ if TYPE_CHECKING :
35
+ from pandas .io .formats .format import DataFrameFormatter
36
+
33
37
34
38
class CSVFormatter :
35
39
def __init__ (
36
40
self ,
37
- obj ,
41
+ formatter : "DataFrameFormatter" ,
38
42
path_or_buf : Optional [FilePathOrBuffer [str ]] = None ,
39
43
sep : str = "," ,
40
- na_rep : str = "" ,
41
- float_format : Optional [str ] = None ,
42
44
cols : Optional [Sequence [Label ]] = None ,
43
- header : Union [bool , Sequence [Hashable ]] = True ,
44
- index : bool = True ,
45
45
index_label : Optional [IndexLabel ] = None ,
46
46
mode : str = "w" ,
47
47
encoding : Optional [str ] = None ,
@@ -54,10 +54,11 @@ def __init__(
54
54
date_format : Optional [str ] = None ,
55
55
doublequote : bool = True ,
56
56
escapechar : Optional [str ] = None ,
57
- decimal = "." ,
58
57
storage_options : StorageOptions = None ,
59
58
):
60
- self .obj = obj
59
+ self .fmt = formatter
60
+
61
+ self .obj = self .fmt .frame
61
62
62
63
self .encoding = encoding or "utf-8"
63
64
@@ -79,35 +80,45 @@ def __init__(
79
80
self .mode = ioargs .mode
80
81
81
82
self .sep = sep
82
- self .na_rep = na_rep
83
- self .float_format = float_format
84
- self .decimal = decimal
85
- self .header = header
86
- self .index = index
87
- self .index_label = index_label
83
+ self .index_label = self ._initialize_index_label (index_label )
88
84
self .errors = errors
89
85
self .quoting = quoting or csvlib .QUOTE_MINIMAL
90
- self .quotechar = quotechar
86
+ self .quotechar = self . _initialize_quotechar ( quotechar )
91
87
self .doublequote = doublequote
92
88
self .escapechar = escapechar
93
89
self .line_terminator = line_terminator or os .linesep
94
90
self .date_format = date_format
95
- self .cols = cols # type: ignore[assignment]
96
- self .chunksize = chunksize # type: ignore[assignment]
91
+ self .cols = self ._initialize_columns (cols )
92
+ self .chunksize = self ._initialize_chunksize (chunksize )
93
+
94
+ @property
95
+ def na_rep (self ) -> str :
96
+ return self .fmt .na_rep
97
+
98
+ @property
99
+ def float_format (self ) -> Optional ["FloatFormatType" ]:
100
+ return self .fmt .float_format
97
101
98
102
@property
99
- def index_label (self ) -> IndexLabel :
100
- return self ._index_label
103
+ def decimal (self ) -> str :
104
+ return self .fmt . decimal
101
105
102
- @index_label .setter
103
- def index_label (self , index_label : Optional [IndexLabel ]) -> None :
106
+ @property
107
+ def header (self ) -> Union [bool , Sequence [str ]]:
108
+ return self .fmt .header
109
+
110
+ @property
111
+ def index (self ) -> bool :
112
+ return self .fmt .index
113
+
114
+ def _initialize_index_label (self , index_label : Optional [IndexLabel ]) -> IndexLabel :
104
115
if index_label is not False :
105
116
if index_label is None :
106
- index_label = self ._get_index_label_from_obj ()
117
+ return self ._get_index_label_from_obj ()
107
118
elif not isinstance (index_label , (list , tuple , np .ndarray , ABCIndexClass )):
108
119
# given a string for a DF with Index
109
- index_label = [index_label ]
110
- self . _index_label = index_label
120
+ return [index_label ]
121
+ return index_label
111
122
112
123
def _get_index_label_from_obj (self ) -> List [str ]:
113
124
if isinstance (self .obj .index , ABCMultiIndex ):
@@ -122,30 +133,17 @@ def _get_index_label_flat(self) -> List[str]:
122
133
index_label = self .obj .index .name
123
134
return ["" ] if index_label is None else [index_label ]
124
135
125
- @property
126
- def quotechar (self ) -> Optional [str ]:
136
+ def _initialize_quotechar (self , quotechar : Optional [str ]) -> Optional [str ]:
127
137
if self .quoting != csvlib .QUOTE_NONE :
128
138
# prevents crash in _csv
129
- return self . _quotechar
139
+ return quotechar
130
140
return None
131
141
132
- @quotechar .setter
133
- def quotechar (self , quotechar : Optional [str ]) -> None :
134
- self ._quotechar = quotechar
135
-
136
142
@property
137
143
def has_mi_columns (self ) -> bool :
138
144
return bool (isinstance (self .obj .columns , ABCMultiIndex ))
139
145
140
- @property
141
- def cols (self ) -> Sequence [Label ]:
142
- return self ._cols
143
-
144
- @cols .setter
145
- def cols (self , cols : Optional [Sequence [Label ]]) -> None :
146
- self ._cols = self ._refine_cols (cols )
147
-
148
- def _refine_cols (self , cols : Optional [Sequence [Label ]]) -> Sequence [Label ]:
146
+ def _initialize_columns (self , cols : Optional [Sequence [Label ]]) -> Sequence [Label ]:
149
147
# validate mi options
150
148
if self .has_mi_columns :
151
149
if cols is not None :
@@ -161,12 +159,16 @@ def _refine_cols(self, cols: Optional[Sequence[Label]]) -> Sequence[Label]:
161
159
162
160
# update columns to include possible multiplicity of dupes
163
161
# and make sure sure cols is just a list of labels
164
- cols = self .obj .columns
165
- if isinstance (cols , ABCIndexClass ):
166
- return cols ._format_native_types (** self ._number_format )
162
+ new_cols = self .obj .columns
163
+ if isinstance (new_cols , ABCIndexClass ):
164
+ return new_cols ._format_native_types (** self ._number_format )
167
165
else :
168
- assert isinstance (cols , Sequence )
169
- return list (cols )
166
+ return list (new_cols )
167
+
168
+ def _initialize_chunksize (self , chunksize : Optional [int ]) -> int :
169
+ if chunksize is None :
170
+ return (100000 // (len (self .cols ) or 1 )) or 1
171
+ return int (chunksize )
170
172
171
173
@property
172
174
def _number_format (self ) -> Dict [str , Any ]:
@@ -179,17 +181,6 @@ def _number_format(self) -> Dict[str, Any]:
179
181
decimal = self .decimal ,
180
182
)
181
183
182
- @property
183
- def chunksize (self ) -> int :
184
- return self ._chunksize
185
-
186
- @chunksize .setter
187
- def chunksize (self , chunksize : Optional [int ]) -> None :
188
- if chunksize is None :
189
- chunksize = (100000 // (len (self .cols ) or 1 )) or 1
190
- assert chunksize is not None
191
- self ._chunksize = int (chunksize )
192
-
193
184
@property
194
185
def data_index (self ) -> Index :
195
186
data_index = self .obj .index
0 commit comments