diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 252534a0cb790..5b76e6ad29321 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -63,6 +63,7 @@ IndexLabel, Level, PythonFuncType, + ReadBuffer, Renamer, Scalar, StorageOptions, @@ -2948,15 +2949,15 @@ def to_xml( root_name: str | None = "data", row_name: str | None = "row", na_rep: str | None = None, - attr_cols: str | list[str] | None = None, - elem_cols: str | list[str] | None = None, + attr_cols: list[str] | None = None, + elem_cols: list[str] | None = None, namespaces: dict[str | None, str] | None = None, prefix: str | None = None, encoding: str = "utf-8", xml_declaration: bool | None = True, pretty_print: bool | None = True, parser: str | None = "lxml", - stylesheet: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ) -> str | None: diff --git a/pandas/io/formats/xml.py b/pandas/io/formats/xml.py index 1b11bb12757bb..aa69792cb1db0 100644 --- a/pandas/io/formats/xml.py +++ b/pandas/io/formats/xml.py @@ -96,8 +96,8 @@ class BaseXMLFormatter: def __init__( self, frame: DataFrame, - path_or_buffer: FilePath | WriteBuffer[bytes] | None = None, - index: bool | None = True, + path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + index: bool = True, root_name: str | None = "data", row_name: str | None = "row", na_rep: str | None = None, @@ -108,7 +108,7 @@ def __init__( encoding: str = "utf-8", xml_declaration: bool | None = True, pretty_print: bool | None = True, - stylesheet: FilePath | ReadBuffer[str] | None = None, + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ) -> None: @@ -132,6 +132,11 @@ def __init__( self.orig_cols = self.frame.columns.tolist() self.frame_dicts = self.process_dataframe() + self.validate_columns() + self.validate_encoding() + self.prefix_uri = self.get_prefix_uri() + self.handle_indexes() + def build_tree(self) -> bytes: """ Build tree from data. @@ -189,8 +194,8 @@ def process_dataframe(self) -> dict[int | str, dict[str, Any]]: if self.index: df = df.reset_index() - if self.na_rep: - df = df.replace({None: self.na_rep, float("nan"): self.na_rep}) + if self.na_rep is not None: + df = df.fillna(self.na_rep) return df.to_dict(orient="index") @@ -247,7 +252,7 @@ def other_namespaces(self) -> dict: return nmsp_dict - def build_attribs(self) -> None: + def build_attribs(self, d: dict[str, Any], elem_row: Any) -> Any: """ Create attributes of row. @@ -255,9 +260,29 @@ def build_attribs(self) -> None: works with tuples for multindex or hierarchical columns. """ - raise AbstractMethodError(self) + if not self.attr_cols: + return elem_row + + for col in self.attr_cols: + attr_name = self._get_flat_col_name(col) + try: + if not isna(d[col]): + elem_row.attrib[attr_name] = str(d[col]) + except KeyError: + raise KeyError(f"no valid column, {col}") + return elem_row + + def _get_flat_col_name(self, col: str | tuple) -> str: + flat_col = col + if isinstance(col, tuple): + flat_col = ( + "".join([str(c) for c in col]).strip() + if "" in col + else "_".join([str(c) for c in col]).strip() + ) + return f"{self.prefix_uri}{flat_col}" - def build_elems(self) -> None: + def build_elems(self, d: dict[str, Any], elem_row: Any) -> None: """ Create child elements of row. @@ -267,6 +292,19 @@ def build_elems(self) -> None: raise AbstractMethodError(self) + def _build_elems(self, sub_element_cls, d: dict[str, Any], elem_row: Any) -> None: + + if not self.elem_cols: + return + + for col in self.elem_cols: + elem_name = self._get_flat_col_name(col) + try: + val = None if isna(d[col]) or d[col] == "" else str(d[col]) + sub_element_cls(elem_row, elem_name).text = val + except KeyError: + raise KeyError(f"no valid column, {col}") + def write_output(self) -> str | None: xml_doc = self.build_tree() @@ -291,14 +329,6 @@ class EtreeXMLFormatter(BaseXMLFormatter): modules: `xml.etree.ElementTree` and `xml.dom.minidom`. """ - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.validate_columns() - self.validate_encoding() - self.handle_indexes() - self.prefix_uri = self.get_prefix_uri() - def build_tree(self) -> bytes: from xml.etree.ElementTree import ( Element, @@ -311,16 +341,15 @@ def build_tree(self) -> bytes: ) for d in self.frame_dicts.values(): - self.d = d - self.elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}") + elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}") if not self.attr_cols and not self.elem_cols: - self.elem_cols = list(self.d.keys()) - self.build_elems() + self.elem_cols = list(d.keys()) + self.build_elems(d, elem_row) else: - self.build_attribs() - self.build_elems() + elem_row = self.build_attribs(d, elem_row) + self.build_elems(d, elem_row) self.out_xml = tostring(self.root, method="xml", encoding=self.encoding) @@ -357,56 +386,10 @@ def get_prefix_uri(self) -> str: return uri - def build_attribs(self) -> None: - if not self.attr_cols: - return - - for col in self.attr_cols: - flat_col = col - if isinstance(col, tuple): - flat_col = ( - "".join([str(c) for c in col]).strip() - if "" in col - else "_".join([str(c) for c in col]).strip() - ) - - attr_name = f"{self.prefix_uri}{flat_col}" - try: - val = ( - None - if self.d[col] is None or self.d[col] != self.d[col] - else str(self.d[col]) - ) - if val is not None: - self.elem_row.attrib[attr_name] = val - except KeyError: - raise KeyError(f"no valid column, {col}") - - def build_elems(self) -> None: + def build_elems(self, d: dict[str, Any], elem_row: Any) -> None: from xml.etree.ElementTree import SubElement - if not self.elem_cols: - return - - for col in self.elem_cols: - flat_col = col - if isinstance(col, tuple): - flat_col = ( - "".join([str(c) for c in col]).strip() - if "" in col - else "_".join([str(c) for c in col]).strip() - ) - - elem_name = f"{self.prefix_uri}{flat_col}" - try: - val = ( - None - if self.d[col] in [None, ""] or self.d[col] != self.d[col] - else str(self.d[col]) - ) - SubElement(self.elem_row, elem_name).text = val - except KeyError: - raise KeyError(f"no valid column, {col}") + self._build_elems(SubElement, d, elem_row) def prettify_tree(self) -> bytes: """ @@ -458,12 +441,7 @@ class LxmlXMLFormatter(BaseXMLFormatter): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.validate_columns() - self.validate_encoding() - self.prefix_uri = self.get_prefix_uri() - self.convert_empty_str_key() - self.handle_indexes() def build_tree(self) -> bytes: """ @@ -481,16 +459,15 @@ def build_tree(self) -> bytes: self.root = Element(f"{self.prefix_uri}{self.root_name}", nsmap=self.namespaces) for d in self.frame_dicts.values(): - self.d = d - self.elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}") + elem_row = SubElement(self.root, f"{self.prefix_uri}{self.row_name}") if not self.attr_cols and not self.elem_cols: - self.elem_cols = list(self.d.keys()) - self.build_elems() + self.elem_cols = list(d.keys()) + self.build_elems(d, elem_row) else: - self.build_attribs() - self.build_elems() + elem_row = self.build_attribs(d, elem_row) + self.build_elems(d, elem_row) self.out_xml = tostring( self.root, @@ -529,54 +506,10 @@ def get_prefix_uri(self) -> str: return uri - def build_attribs(self) -> None: - if not self.attr_cols: - return - - for col in self.attr_cols: - flat_col = col - if isinstance(col, tuple): - flat_col = ( - "".join([str(c) for c in col]).strip() - if "" in col - else "_".join([str(c) for c in col]).strip() - ) - - attr_name = f"{self.prefix_uri}{flat_col}" - try: - val = ( - None - if self.d[col] is None or self.d[col] != self.d[col] - else str(self.d[col]) - ) - if val is not None: - self.elem_row.attrib[attr_name] = val - except KeyError: - raise KeyError(f"no valid column, {col}") - - def build_elems(self) -> None: + def build_elems(self, d: dict[str, Any], elem_row: Any) -> None: from lxml.etree import SubElement - if not self.elem_cols: - return - - for col in self.elem_cols: - flat_col = col - if isinstance(col, tuple): - flat_col = ( - "".join([str(c) for c in col]).strip() - if "" in col - else "_".join([str(c) for c in col]).strip() - ) - - elem_name = f"{self.prefix_uri}{flat_col}" - try: - val = ( - None if isna(self.d[col]) or self.d[col] == "" else str(self.d[col]) - ) - SubElement(self.elem_row, elem_name).text = val - except KeyError: - raise KeyError(f"no valid column, {col}") + self._build_elems(SubElement, d, elem_row) def transform_doc(self) -> bytes: """ diff --git a/pandas/tests/io/xml/test_to_xml.py b/pandas/tests/io/xml/test_to_xml.py index c8828c08dba44..aeec163ed134a 100644 --- a/pandas/tests/io/xml/test_to_xml.py +++ b/pandas/tests/io/xml/test_to_xml.py @@ -1308,8 +1308,7 @@ def test_filename_and_suffix_comp(parser, compression_only): assert geom_xml == output.strip() -@td.skip_if_no("lxml") -def test_ea_dtypes(any_numeric_ea_dtype): +def test_ea_dtypes(any_numeric_ea_dtype, parser): # GH#43903 expected = """ @@ -1319,8 +1318,8 @@ def test_ea_dtypes(any_numeric_ea_dtype): """ df = DataFrame({"a": [NA]}).astype(any_numeric_ea_dtype) - result = df.to_xml() - assert result.strip() == expected + result = df.to_xml(parser=parser) + assert equalize_decl(result).strip() == expected def test_unsuported_compression(datapath, parser):