Skip to content

TYP: Add typing for remaining IO XML methods with conditional for lxml #40340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 16, 2021
Merged
4 changes: 2 additions & 2 deletions pandas/io/formats/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class EtreeXMLFormatter(BaseXMLFormatter):
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.validate_columns()
Expand Down Expand Up @@ -452,7 +452,7 @@ class LxmlXMLFormatter(BaseXMLFormatter):
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.validate_columns()
Expand Down
73 changes: 36 additions & 37 deletions pandas/io/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
stylesheet,
compression,
storage_options,
):
) -> None:
self.path_or_buffer = path_or_buffer
self.xpath = xpath
self.namespaces = namespaces
Expand Down Expand Up @@ -187,14 +187,13 @@ def _validate_names(self) -> None:
"""
raise AbstractMethodError(self)

def _parse_doc(self):
def _parse_doc(self, raw_doc) -> bytes:
"""
Build tree from io.
Build tree from path_or_buffer.

This method will parse io object into tree for parsing
conditionally by its specific object type.
This method will parse XML object into tree
either from string/bytes or file location.
"""

raise AbstractMethodError(self)


Expand All @@ -204,22 +203,18 @@ class _EtreeFrameParser(_XMLFrameParser):
standard library XML module: `xml.etree.ElementTree`.
"""

from xml.etree.ElementTree import (
Element,
ElementTree,
)

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def parse_data(self) -> List[Dict[str, Optional[str]]]:
from xml.etree.ElementTree import XML

if self.stylesheet is not None:
raise ValueError(
"To use stylesheet, you need lxml installed and selected as parser."
)

self.xml_doc = self._parse_doc()
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))

self._validate_path()
self._validate_names()
Expand Down Expand Up @@ -356,14 +351,15 @@ def _validate_names(self) -> None:
f"{type(self.names).__name__} is not a valid type for names"
)

def _parse_doc(self) -> Union[Element, ElementTree]:
def _parse_doc(self, raw_doc) -> bytes:
from xml.etree.ElementTree import (
XMLParser,
parse,
tostring,
)

handle_data = get_data_from_filepath(
filepath_or_buffer=self.path_or_buffer,
filepath_or_buffer=raw_doc,
encoding=self.encoding,
compression=self.compression,
storage_options=self.storage_options,
Expand All @@ -373,7 +369,7 @@ def _parse_doc(self) -> Union[Element, ElementTree]:
curr_parser = XMLParser(encoding=self.encoding)
r = parse(xml_data, parser=curr_parser)

return r
return tostring(r.getroot())


class _LxmlFrameParser(_XMLFrameParser):
Expand All @@ -383,7 +379,7 @@ class _LxmlFrameParser(_XMLFrameParser):
XPath 1.0 and XSLT 1.0.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def parse_data(self) -> List[Dict[str, Optional[str]]]:
Expand All @@ -394,12 +390,13 @@ def parse_data(self) -> List[Dict[str, Optional[str]]]:
validate xpath, names, optionally parse and run XSLT,
and parse original or transformed XML and return specific nodes.
"""
from lxml.etree import XML

self.xml_doc = self._parse_doc(self.path_or_buffer)
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))

if self.stylesheet is not None:
self.xsl_doc = self._parse_doc(self.stylesheet)
self.xml_doc = self._transform_doc()
self.xsl_doc = XML(self._parse_doc(self.stylesheet))
self.xml_doc = XML(self._transform_doc())

self._validate_path()
self._validate_names()
Expand Down Expand Up @@ -491,21 +488,6 @@ def _parse_nodes(self) -> List[Dict[str, Optional[str]]]:

return dicts

def _transform_doc(self):
"""
Transform original tree using stylesheet.

This method will transform original xml using XSLT script into
an ideally flatter xml document for easier parsing and migration
to Data Frame.
"""
from lxml.etree import XSLT

transformer = XSLT(self.xsl_doc)
new_doc = transformer(self.xml_doc)

return new_doc

def _validate_path(self) -> None:

msg = (
Expand Down Expand Up @@ -553,11 +535,12 @@ def _validate_names(self) -> None:
f"{type(self.names).__name__} is not a valid type for names"
)

def _parse_doc(self, raw_doc):
def _parse_doc(self, raw_doc) -> bytes:
from lxml.etree import (
XMLParser,
fromstring,
parse,
tostring,
)

handle_data = get_data_from_filepath(
Expand All @@ -577,7 +560,22 @@ def _parse_doc(self, raw_doc):
else:
doc = parse(xml_data, parser=curr_parser)

return doc
return tostring(doc)

def _transform_doc(self) -> bytes:
"""
Transform original tree using stylesheet.

This method will transform original xml using XSLT script into
an ideally flatter xml document for easier parsing and migration
to Data Frame.
"""
from lxml.etree import XSLT

transformer = XSLT(self.xsl_doc)
new_doc = transformer(self.xml_doc)

return bytes(new_doc)


def get_data_from_filepath(
Expand Down Expand Up @@ -695,6 +693,7 @@ def _parse(
"""

lxml = import_optional_dependency("lxml.etree", errors="ignore")

p: Union[_EtreeFrameParser, _LxmlFrameParser]

if parser == "lxml":
Expand Down
18 changes: 9 additions & 9 deletions pandas/tests/io/xml/test_to_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
StringIO,
)
import os
import sys
from typing import Union

import numpy as np
import pytest

from pandas.compat import PY38
import pandas.util._test_decorators as td

from pandas import DataFrame
Expand Down Expand Up @@ -364,8 +364,8 @@ def test_na_empty_elem_option(datapath, parser):


@pytest.mark.skipif(
sys.version_info < (3, 8),
reason=("etree alpha ordered attributes <= py3.7"),
not PY38,
reason=("etree alpha ordered attributes < py 3.8"),
)
def test_attrs_cols_nan_output(datapath, parser):
expected = """\
Expand All @@ -383,8 +383,8 @@ def test_attrs_cols_nan_output(datapath, parser):


@pytest.mark.skipif(
sys.version_info < (3, 8),
reason=("etree alpha ordered attributes <= py3.7"),
not PY38,
reason=("etree alpha ordered attributes < py3.8"),
)
def test_attrs_cols_prefix(datapath, parser):
expected = """\
Expand Down Expand Up @@ -541,8 +541,8 @@ def test_hierarchical_columns(datapath, parser):


@pytest.mark.skipif(
sys.version_info < (3, 8),
reason=("etree alpha ordered attributes <= py3.7"),
not PY38,
reason=("etree alpha ordered attributes < py3.8"),
)
def test_hierarchical_attrs_columns(datapath, parser):
expected = """\
Expand Down Expand Up @@ -614,8 +614,8 @@ def test_multi_index(datapath, parser):


@pytest.mark.skipif(
sys.version_info < (3, 8),
reason=("etree alpha ordered attributes <= py3.7"),
not PY38,
reason=("etree alpha ordered attributes < py3.8"),
)
def test_multi_index_attrs_cols(datapath, parser):
expected = """\
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/io/xml/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import pytest

from pandas.compat import PY38
import pandas.util._test_decorators as td

from pandas import DataFrame
Expand Down Expand Up @@ -253,6 +254,10 @@ def test_parser_consistency_file(datapath):
@tm.network
@pytest.mark.slow
@td.skip_if_no("lxml")
@pytest.mark.skipif(
not PY38,
reason=("etree alpha ordered attributes < py3.8"),
)
def test_parser_consistency_url(datapath):
url = (
"https://data.cityofchicago.org/api/views/"
Expand Down