Skip to content

Commit f4083c3

Browse files
ParfaitGJulianWgs
authored and committed
TYP: Add typing for remaining IO XML methods with conditional for lxml (pandas-dev#40340)
1 parent 0099438 commit f4083c3

File tree

4 files changed

+52
-48
lines changed

4 files changed

+52
-48
lines changed

pandas/io/formats/xml.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ class EtreeXMLFormatter(BaseXMLFormatter):
288288
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
289289
"""
290290

291-
def __init__(self, *args, **kwargs):
291+
def __init__(self, *args, **kwargs) -> None:
292292
super().__init__(*args, **kwargs)
293293

294294
self.validate_columns()
@@ -452,7 +452,7 @@ class LxmlXMLFormatter(BaseXMLFormatter):
452452
modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
453453
"""
454454

455-
def __init__(self, *args, **kwargs):
455+
def __init__(self, *args, **kwargs) -> None:
456456
super().__init__(*args, **kwargs)
457457

458458
self.validate_columns()

pandas/io/xml.py

+36-37
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
stylesheet,
112112
compression,
113113
storage_options,
114-
):
114+
) -> None:
115115
self.path_or_buffer = path_or_buffer
116116
self.xpath = xpath
117117
self.namespaces = namespaces
@@ -187,14 +187,13 @@ def _validate_names(self) -> None:
187187
"""
188188
raise AbstractMethodError(self)
189189

190-
def _parse_doc(self):
190+
def _parse_doc(self, raw_doc) -> bytes:
191191
"""
192-
Build tree from io.
192+
Build tree from path_or_buffer.
193193
194-
This method will parse io object into tree for parsing
195-
conditionally by its specific object type.
194+
This method will parse XML object into tree
195+
either from string/bytes or file location.
196196
"""
197-
198197
raise AbstractMethodError(self)
199198

200199

@@ -204,22 +203,18 @@ class _EtreeFrameParser(_XMLFrameParser):
204203
standard library XML module: `xml.etree.ElementTree`.
205204
"""
206205

207-
from xml.etree.ElementTree import (
208-
Element,
209-
ElementTree,
210-
)
211-
212-
def __init__(self, *args, **kwargs):
206+
def __init__(self, *args, **kwargs) -> None:
213207
super().__init__(*args, **kwargs)
214208

215209
def parse_data(self) -> List[Dict[str, Optional[str]]]:
210+
from xml.etree.ElementTree import XML
216211

217212
if self.stylesheet is not None:
218213
raise ValueError(
219214
"To use stylesheet, you need lxml installed and selected as parser."
220215
)
221216

222-
self.xml_doc = self._parse_doc()
217+
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))
223218

224219
self._validate_path()
225220
self._validate_names()
@@ -356,14 +351,15 @@ def _validate_names(self) -> None:
356351
f"{type(self.names).__name__} is not a valid type for names"
357352
)
358353

359-
def _parse_doc(self) -> Union[Element, ElementTree]:
354+
def _parse_doc(self, raw_doc) -> bytes:
360355
from xml.etree.ElementTree import (
361356
XMLParser,
362357
parse,
358+
tostring,
363359
)
364360

365361
handle_data = get_data_from_filepath(
366-
filepath_or_buffer=self.path_or_buffer,
362+
filepath_or_buffer=raw_doc,
367363
encoding=self.encoding,
368364
compression=self.compression,
369365
storage_options=self.storage_options,
@@ -373,7 +369,7 @@ def _parse_doc(self) -> Union[Element, ElementTree]:
373369
curr_parser = XMLParser(encoding=self.encoding)
374370
r = parse(xml_data, parser=curr_parser)
375371

376-
return r
372+
return tostring(r.getroot())
377373

378374

379375
class _LxmlFrameParser(_XMLFrameParser):
@@ -383,7 +379,7 @@ class _LxmlFrameParser(_XMLFrameParser):
383379
XPath 1.0 and XSLT 1.0.
384380
"""
385381

386-
def __init__(self, *args, **kwargs):
382+
def __init__(self, *args, **kwargs) -> None:
387383
super().__init__(*args, **kwargs)
388384

389385
def parse_data(self) -> List[Dict[str, Optional[str]]]:
@@ -394,12 +390,13 @@ def parse_data(self) -> List[Dict[str, Optional[str]]]:
394390
validate xpath, names, optionally parse and run XSLT,
395391
and parse original or transformed XML and return specific nodes.
396392
"""
393+
from lxml.etree import XML
397394

398-
self.xml_doc = self._parse_doc(self.path_or_buffer)
395+
self.xml_doc = XML(self._parse_doc(self.path_or_buffer))
399396

400397
if self.stylesheet is not None:
401-
self.xsl_doc = self._parse_doc(self.stylesheet)
402-
self.xml_doc = self._transform_doc()
398+
self.xsl_doc = XML(self._parse_doc(self.stylesheet))
399+
self.xml_doc = XML(self._transform_doc())
403400

404401
self._validate_path()
405402
self._validate_names()
@@ -491,21 +488,6 @@ def _parse_nodes(self) -> List[Dict[str, Optional[str]]]:
491488

492489
return dicts
493490

494-
def _transform_doc(self):
495-
"""
496-
Transform original tree using stylesheet.
497-
498-
This method will transform original xml using XSLT script into
499-
am ideally flatter xml document for easier parsing and migration
500-
to Data Frame.
501-
"""
502-
from lxml.etree import XSLT
503-
504-
transformer = XSLT(self.xsl_doc)
505-
new_doc = transformer(self.xml_doc)
506-
507-
return new_doc
508-
509491
def _validate_path(self) -> None:
510492

511493
msg = (
@@ -553,11 +535,12 @@ def _validate_names(self) -> None:
553535
f"{type(self.names).__name__} is not a valid type for names"
554536
)
555537

556-
def _parse_doc(self, raw_doc):
538+
def _parse_doc(self, raw_doc) -> bytes:
557539
from lxml.etree import (
558540
XMLParser,
559541
fromstring,
560542
parse,
543+
tostring,
561544
)
562545

563546
handle_data = get_data_from_filepath(
@@ -577,7 +560,22 @@ def _parse_doc(self, raw_doc):
577560
else:
578561
doc = parse(xml_data, parser=curr_parser)
579562

580-
return doc
563+
return tostring(doc)
564+
565+
def _transform_doc(self) -> bytes:
566+
"""
567+
Transform original tree using stylesheet.
568+
569+
This method will transform original xml using XSLT script into
570+
an ideally flatter xml document for easier parsing and migration
571+
to Data Frame.
572+
"""
573+
from lxml.etree import XSLT
574+
575+
transformer = XSLT(self.xsl_doc)
576+
new_doc = transformer(self.xml_doc)
577+
578+
return bytes(new_doc)
581579

582580

583581
def get_data_from_filepath(
@@ -695,6 +693,7 @@ def _parse(
695693
"""
696694

697695
lxml = import_optional_dependency("lxml.etree", errors="ignore")
696+
698697
p: Union[_EtreeFrameParser, _LxmlFrameParser]
699698

700699
if parser == "lxml":

pandas/tests/io/xml/test_to_xml.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
StringIO,
44
)
55
import os
6-
import sys
76
from typing import Union
87

98
import numpy as np
109
import pytest
1110

11+
from pandas.compat import PY38
1212
import pandas.util._test_decorators as td
1313

1414
from pandas import DataFrame
@@ -364,8 +364,8 @@ def test_na_empty_elem_option(datapath, parser):
364364

365365

366366
@pytest.mark.skipif(
367-
sys.version_info < (3, 8),
368-
reason=("etree alpha ordered attributes <= py3.7"),
367+
not PY38,
368+
reason=("etree alpha ordered attributes < py 3.8"),
369369
)
370370
def test_attrs_cols_nan_output(datapath, parser):
371371
expected = """\
@@ -383,8 +383,8 @@ def test_attrs_cols_nan_output(datapath, parser):
383383

384384

385385
@pytest.mark.skipif(
386-
sys.version_info < (3, 8),
387-
reason=("etree alpha ordered attributes <= py3.7"),
386+
not PY38,
387+
reason=("etree alpha ordered attributes < py3.8"),
388388
)
389389
def test_attrs_cols_prefix(datapath, parser):
390390
expected = """\
@@ -541,8 +541,8 @@ def test_hierarchical_columns(datapath, parser):
541541

542542

543543
@pytest.mark.skipif(
544-
sys.version_info < (3, 8),
545-
reason=("etree alpha ordered attributes <= py3.7"),
544+
not PY38,
545+
reason=("etree alpha ordered attributes < py3.8"),
546546
)
547547
def test_hierarchical_attrs_columns(datapath, parser):
548548
expected = """\
@@ -614,8 +614,8 @@ def test_multi_index(datapath, parser):
614614

615615

616616
@pytest.mark.skipif(
617-
sys.version_info < (3, 8),
618-
reason=("etree alpha ordered attributes <= py3.7"),
617+
not PY38,
618+
reason=("etree alpha ordered attributes < py3.8"),
619619
)
620620
def test_multi_index_attrs_cols(datapath, parser):
621621
expected = """\

pandas/tests/io/xml/test_xml.py

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
import pytest
1111

12+
from pandas.compat import PY38
1213
import pandas.util._test_decorators as td
1314

1415
from pandas import DataFrame
@@ -253,6 +254,10 @@ def test_parser_consistency_file(datapath):
253254
@tm.network
254255
@pytest.mark.slow
255256
@td.skip_if_no("lxml")
257+
@pytest.mark.skipif(
258+
not PY38,
259+
reason=("etree alpha ordered attributes < py3.8"),
260+
)
256261
def test_parser_consistency_url(datapath):
257262
url = (
258263
"https://data.cityofchicago.org/api/views/"

0 commit comments

Comments
 (0)