Skip to content

TYP: Add typing for remaining IO XML methods with conditional for lxml #40340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 16, 2021
Merged
4 changes: 2 additions & 2 deletions pandas/io/formats/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class EtreeXMLFormatter(BaseXMLFormatter):
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.validate_columns()
Expand Down Expand Up @@ -452,7 +452,7 @@ class LxmlXMLFormatter(BaseXMLFormatter):
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.validate_columns()
Expand Down
73 changes: 36 additions & 37 deletions pandas/io/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
stylesheet,
compression,
storage_options,
):
) -> None:
self.path_or_buffer = path_or_buffer
self.xpath = xpath
self.namespaces = namespaces
Expand Down Expand Up @@ -187,14 +187,13 @@ def _validate_names(self) -> None:
"""
raise AbstractMethodError(self)

def _parse_doc(self):
def _parse_doc(self, raw_doc) -> bytes:
"""
Build tree from io.
Build tree from path_or_buffer.

This method will parse io object into tree for parsing
conditionally by its specific object type.
This method will parse XML object into tree
either from string/bytes or file location.
"""

raise AbstractMethodError(self)


Expand All @@ -204,22 +203,18 @@ class _EtreeFrameParser(_XMLFrameParser):
standard library XML module: `xml.etree.ElementTree`.
"""

from xml.etree.ElementTree import (
Element,
ElementTree,
)

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def parse_data(self) -> List[Dict[str, Optional[str]]]:
from xml.etree.ElementTree import XML

if self.stylesheet is not None:
raise ValueError(
"To use stylesheet, you need lxml installed and selected as parser."
)

self.xml_doc = self._parse_doc()
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))

self._validate_path()
self._validate_names()
Expand Down Expand Up @@ -356,14 +351,15 @@ def _validate_names(self) -> None:
f"{type(self.names).__name__} is not a valid type for names"
)

def _parse_doc(self) -> Union[Element, ElementTree]:
def _parse_doc(self, raw_doc) -> bytes:
from xml.etree.ElementTree import (
XMLParser,
parse,
tostring,
)

handle_data = get_data_from_filepath(
filepath_or_buffer=self.path_or_buffer,
filepath_or_buffer=raw_doc,
encoding=self.encoding,
compression=self.compression,
storage_options=self.storage_options,
Expand All @@ -373,7 +369,7 @@ def _parse_doc(self) -> Union[Element, ElementTree]:
curr_parser = XMLParser(encoding=self.encoding)
r = parse(xml_data, parser=curr_parser)

return r
return tostring(r.getroot())


class _LxmlFrameParser(_XMLFrameParser):
Expand All @@ -383,7 +379,7 @@ class _LxmlFrameParser(_XMLFrameParser):
XPath 1.0 and XSLT 1.0.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def parse_data(self) -> List[Dict[str, Optional[str]]]:
Expand All @@ -394,12 +390,13 @@ def parse_data(self) -> List[Dict[str, Optional[str]]]:
validate xpath, names, optionally parse and run XSLT,
and parse original or transformed XML and return specific nodes.
"""
from lxml.etree import XML

self.xml_doc = self._parse_doc(self.path_or_buffer)
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))

if self.stylesheet is not None:
self.xsl_doc = self._parse_doc(self.stylesheet)
self.xml_doc = self._transform_doc()
self.xsl_doc = XML(self._parse_doc(self.stylesheet))
self.xml_doc = XML(self._transform_doc())

self._validate_path()
self._validate_names()
Expand Down Expand Up @@ -491,21 +488,6 @@ def _parse_nodes(self) -> List[Dict[str, Optional[str]]]:

return dicts

def _transform_doc(self):
"""
Transform original tree using stylesheet.

This method will transform original xml using XSLT script into
an ideally flatter xml document for easier parsing and migration
to Data Frame.
"""
from lxml.etree import XSLT

transformer = XSLT(self.xsl_doc)
new_doc = transformer(self.xml_doc)

return new_doc

def _validate_path(self) -> None:

msg = (
Expand Down Expand Up @@ -553,11 +535,12 @@ def _validate_names(self) -> None:
f"{type(self.names).__name__} is not a valid type for names"
)

def _parse_doc(self, raw_doc):
def _parse_doc(self, raw_doc) -> bytes:
from lxml.etree import (
XMLParser,
fromstring,
parse,
tostring,
)

handle_data = get_data_from_filepath(
Expand All @@ -577,7 +560,22 @@ def _parse_doc(self, raw_doc):
else:
doc = parse(xml_data, parser=curr_parser)

return doc
return tostring(doc)

def _transform_doc(self) -> bytes:
"""
Transform original tree using stylesheet.

This method will transform original xml using XSLT script into
an ideally flatter xml document for easier parsing and migration
to Data Frame.
"""
from lxml.etree import XSLT

transformer = XSLT(self.xsl_doc)
new_doc = transformer(self.xml_doc)

return bytes(new_doc)


def get_data_from_filepath(
Expand Down Expand Up @@ -695,6 +693,7 @@ def _parse(
"""

lxml = import_optional_dependency("lxml.etree", errors="ignore")

p: Union[_EtreeFrameParser, _LxmlFrameParser]

if parser == "lxml":
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/io/xml/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
StringIO,
)
import os
import sys
from typing import Union
from urllib.error import HTTPError

Expand Down Expand Up @@ -253,6 +254,10 @@ def test_parser_consistency_file(datapath):
@tm.network
@pytest.mark.slow
@td.skip_if_no("lxml")
@pytest.mark.skipif(
sys.version_info < (3, 8),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use the `PY38` flag from compat here instead of checking `sys.version_info` directly?

reason=("etree alpha ordered attributes <= py3.7"),
)
def test_parser_consistency_url(datapath):
url = (
"https://data.cityofchicago.org/api/views/"
Expand Down