Skip to content

Commit 9dfb454

Browse files
authored
REF: Deduplicate to_xml code (#45132)
1 parent abd7436 commit 9dfb454

File tree

3 files changed

+67
-134
lines changed

3 files changed

+67
-134
lines changed

pandas/core/frame.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
IndexLabel,
6464
Level,
6565
PythonFuncType,
66+
ReadBuffer,
6667
Renamer,
6768
Scalar,
6869
StorageOptions,
@@ -2948,15 +2949,15 @@ def to_xml(
29482949
root_name: str | None = "data",
29492950
row_name: str | None = "row",
29502951
na_rep: str | None = None,
2951-
attr_cols: str | list[str] | None = None,
2952-
elem_cols: str | list[str] | None = None,
2952+
attr_cols: list[str] | None = None,
2953+
elem_cols: list[str] | None = None,
29532954
namespaces: dict[str | None, str] | None = None,
29542955
prefix: str | None = None,
29552956
encoding: str = "utf-8",
29562957
xml_declaration: bool | None = True,
29572958
pretty_print: bool | None = True,
29582959
parser: str | None = "lxml",
2959-
stylesheet: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None,
2960+
stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None,
29602961
compression: CompressionOptions = "infer",
29612962
storage_options: StorageOptions = None,
29622963
) -> str | None:

pandas/io/formats/xml.py

+60-127
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ class BaseXMLFormatter:
9696
def __init__(
9797
self,
9898
frame: DataFrame,
99-
path_or_buffer: FilePath | WriteBuffer[bytes] | None = None,
100-
index: bool | None = True,
99+
path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None,
100+
index: bool = True,
101101
root_name: str | None = "data",
102102
row_name: str | None = "row",
103103
na_rep: str | None = None,
@@ -108,7 +108,7 @@ def __init__(
108108
encoding: str = "utf-8",
109109
xml_declaration: bool | None = True,
110110
pretty_print: bool | None = True,
111-
stylesheet: FilePath | ReadBuffer[str] | None = None,
111+
stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None,
112112
compression: CompressionOptions = "infer",
113113
storage_options: StorageOptions = None,
114114
) -> None:
@@ -132,6 +132,11 @@ def __init__(
132132
self.orig_cols = self.frame.columns.tolist()
133133
self.frame_dicts = self.process_dataframe()
134134

135+
self.validate_columns()
136+
self.validate_encoding()
137+
self.prefix_uri = self.get_prefix_uri()
138+
self.handle_indexes()
139+
135140
def build_tree(self) -> bytes:
136141
"""
137142
Build tree from data.
@@ -189,8 +194,8 @@ def process_dataframe(self) -> dict[int | str, dict[str, Any]]:
189194
if self.index:
190195
df = df.reset_index()
191196

192-
if self.na_rep:
193-
df = df.replace({None: self.na_rep, float("nan"): self.na_rep})
197+
if self.na_rep is not None:
198+
df = df.fillna(self.na_rep)
194199

195200
return df.to_dict(orient="index")
196201

@@ -247,17 +252,37 @@ def other_namespaces(self) -> dict:
247252

248253
return nmsp_dict
249254

250-
def build_attribs(self) -> None:
255+
def build_attribs(self, d: dict[str, Any], elem_row: Any) -> Any:
251256
"""
252257
Create attributes of row.
253258
254259
This method adds attributes using attr_cols to row element and
255260
works with tuples for multindex or hierarchical columns.
256261
"""
257262

258-
raise AbstractMethodError(self)
263+
if not self.attr_cols:
264+
return elem_row
265+
266+
for col in self.attr_cols:
267+
attr_name = self._get_flat_col_name(col)
268+
try:
269+
if not isna(d[col]):
270+
elem_row.attrib[attr_name] = str(d[col])
271+
except KeyError:
272+
raise KeyError(f"no valid column, {col}")
273+
return elem_row
274+
275+
def _get_flat_col_name(self, col: str | tuple) -> str:
276+
flat_col = col
277+
if isinstance(col, tuple):
278+
flat_col = (
279+
"".join([str(c) for c in col]).strip()
280+
if "" in col
281+
else "_".join([str(c) for c in col]).strip()
282+
)
283+
return f"{self.prefix_uri}{flat_col}"
259284

260-
def build_elems(self) -> None:
285+
def build_elems(self, d: dict[str, Any], elem_row: Any) -> None:
261286
"""
262287
Create child elements of row.
263288
@@ -267,6 +292,19 @@ def build_elems(self) -> None:
267292

268293
raise AbstractMethodError(self)
269294

295+
def _build_elems(self, sub_element_cls, d: dict[str, Any], elem_row: Any) -> None:
296+
297+
if not self.elem_cols:
298+
return
299+
300+
for col in self.elem_cols:
301+
elem_name = self._get_flat_col_name(col)
302+
try:
303+
val = None if isna(d[col]) or d[col] == "" else str(d[col])
304+
sub_element_cls(elem_row, elem_name).text = val
305+
except KeyError:
306+
raise KeyError(f"no valid column, {col}")
307+
270308
def write_output(self) -> str | None:
271309
xml_doc = self.build_tree()
272310

@@ -291,14 +329,6 @@ class EtreeXMLFormatter(BaseXMLFormatter):
291329
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
292330
"""
293331

294-
def __init__(self, *args, **kwargs) -> None:
295-
super().__init__(*args, **kwargs)
296-
297-
self.validate_columns()
298-
self.validate_encoding()
299-
self.handle_indexes()
300-
self.prefix_uri = self.get_prefix_uri()
301-
302332
def build_tree(self) -> bytes:
303333
from xml.etree.ElementTree import (
304334
Element,
@@ -311,16 +341,15 @@ def build_tree(self) -> bytes:
311341
)
312342

313343
for d in self.frame_dicts.values():
314-
self.d = d
315-
self.elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}")
344+
elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}")
316345

317346
if not self.attr_cols and not self.elem_cols:
318-
self.elem_cols = list(self.d.keys())
319-
self.build_elems()
347+
self.elem_cols = list(d.keys())
348+
self.build_elems(d, elem_row)
320349

321350
else:
322-
self.build_attribs()
323-
self.build_elems()
351+
elem_row = self.build_attribs(d, elem_row)
352+
self.build_elems(d, elem_row)
324353

325354
self.out_xml = tostring(self.root, method="xml", encoding=self.encoding)
326355

@@ -357,56 +386,10 @@ def get_prefix_uri(self) -> str:
357386

358387
return uri
359388

360-
def build_attribs(self) -> None:
361-
if not self.attr_cols:
362-
return
363-
364-
for col in self.attr_cols:
365-
flat_col = col
366-
if isinstance(col, tuple):
367-
flat_col = (
368-
"".join([str(c) for c in col]).strip()
369-
if "" in col
370-
else "_".join([str(c) for c in col]).strip()
371-
)
372-
373-
attr_name = f"{self.prefix_uri}{flat_col}"
374-
try:
375-
val = (
376-
None
377-
if self.d[col] is None or self.d[col] != self.d[col]
378-
else str(self.d[col])
379-
)
380-
if val is not None:
381-
self.elem_row.attrib[attr_name] = val
382-
except KeyError:
383-
raise KeyError(f"no valid column, {col}")
384-
385-
def build_elems(self) -> None:
389+
def build_elems(self, d: dict[str, Any], elem_row: Any) -> None:
386390
from xml.etree.ElementTree import SubElement
387391

388-
if not self.elem_cols:
389-
return
390-
391-
for col in self.elem_cols:
392-
flat_col = col
393-
if isinstance(col, tuple):
394-
flat_col = (
395-
"".join([str(c) for c in col]).strip()
396-
if "" in col
397-
else "_".join([str(c) for c in col]).strip()
398-
)
399-
400-
elem_name = f"{self.prefix_uri}{flat_col}"
401-
try:
402-
val = (
403-
None
404-
if self.d[col] in [None, ""] or self.d[col] != self.d[col]
405-
else str(self.d[col])
406-
)
407-
SubElement(self.elem_row, elem_name).text = val
408-
except KeyError:
409-
raise KeyError(f"no valid column, {col}")
392+
self._build_elems(SubElement, d, elem_row)
410393

411394
def prettify_tree(self) -> bytes:
412395
"""
@@ -458,12 +441,7 @@ class LxmlXMLFormatter(BaseXMLFormatter):
458441
def __init__(self, *args, **kwargs) -> None:
459442
super().__init__(*args, **kwargs)
460443

461-
self.validate_columns()
462-
self.validate_encoding()
463-
self.prefix_uri = self.get_prefix_uri()
464-
465444
self.convert_empty_str_key()
466-
self.handle_indexes()
467445

468446
def build_tree(self) -> bytes:
469447
"""
@@ -481,16 +459,15 @@ def build_tree(self) -> bytes:
481459
self.root = Element(f"{self.prefix_uri}{self.root_name}", nsmap=self.namespaces)
482460

483461
for d in self.frame_dicts.values():
484-
self.d = d
485-
self.elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}")
462+
elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}")
486463

487464
if not self.attr_cols and not self.elem_cols:
488-
self.elem_cols = list(self.d.keys())
489-
self.build_elems()
465+
self.elem_cols = list(d.keys())
466+
self.build_elems(d, elem_row)
490467

491468
else:
492-
self.build_attribs()
493-
self.build_elems()
469+
elem_row = self.build_attribs(d, elem_row)
470+
self.build_elems(d, elem_row)
494471

495472
self.out_xml = tostring(
496473
self.root,
@@ -529,54 +506,10 @@ def get_prefix_uri(self) -> str:
529506

530507
return uri
531508

532-
def build_attribs(self) -> None:
533-
if not self.attr_cols:
534-
return
535-
536-
for col in self.attr_cols:
537-
flat_col = col
538-
if isinstance(col, tuple):
539-
flat_col = (
540-
"".join([str(c) for c in col]).strip()
541-
if "" in col
542-
else "_".join([str(c) for c in col]).strip()
543-
)
544-
545-
attr_name = f"{self.prefix_uri}{flat_col}"
546-
try:
547-
val = (
548-
None
549-
if self.d[col] is None or self.d[col] != self.d[col]
550-
else str(self.d[col])
551-
)
552-
if val is not None:
553-
self.elem_row.attrib[attr_name] = val
554-
except KeyError:
555-
raise KeyError(f"no valid column, {col}")
556-
557-
def build_elems(self) -> None:
509+
def build_elems(self, d: dict[str, Any], elem_row: Any) -> None:
558510
from lxml.etree import SubElement
559511

560-
if not self.elem_cols:
561-
return
562-
563-
for col in self.elem_cols:
564-
flat_col = col
565-
if isinstance(col, tuple):
566-
flat_col = (
567-
"".join([str(c) for c in col]).strip()
568-
if "" in col
569-
else "_".join([str(c) for c in col]).strip()
570-
)
571-
572-
elem_name = f"{self.prefix_uri}{flat_col}"
573-
try:
574-
val = (
575-
None if isna(self.d[col]) or self.d[col] == "" else str(self.d[col])
576-
)
577-
SubElement(self.elem_row, elem_name).text = val
578-
except KeyError:
579-
raise KeyError(f"no valid column, {col}")
512+
self._build_elems(SubElement, d, elem_row)
580513

581514
def transform_doc(self) -> bytes:
582515
"""

pandas/tests/io/xml/test_to_xml.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,7 @@ def test_filename_and_suffix_comp(parser, compression_only):
13081308
assert geom_xml == output.strip()
13091309

13101310

1311-
@td.skip_if_no("lxml")
1312-
def test_ea_dtypes(any_numeric_ea_dtype):
1311+
def test_ea_dtypes(any_numeric_ea_dtype, parser):
13131312
# GH#43903
13141313
expected = """<?xml version='1.0' encoding='utf-8'?>
13151314
<data>
@@ -1319,8 +1318,8 @@ def test_ea_dtypes(any_numeric_ea_dtype):
13191318
</row>
13201319
</data>"""
13211320
df = DataFrame({"a": [NA]}).astype(any_numeric_ea_dtype)
1322-
result = df.to_xml()
1323-
assert result.strip() == expected
1321+
result = df.to_xml(parser=parser)
1322+
assert equalize_decl(result).strip() == expected
13241323

13251324

13261325
def test_unsuported_compression(datapath, parser):

0 commit comments

Comments
 (0)