Skip to content

Commit eddd9f0

Browse files
jbrockmendel authored and jreback committed
REF: ensure name and cname are always str (#29692)
1 parent 84fcbb8 commit eddd9f0

File tree

1 file changed

+72
-36
lines changed

1 file changed

+72
-36
lines changed

pandas/io/pytables.py

+72-36
Original file line number | Diff line number | Diff line change
@@ -1710,29 +1710,37 @@ class IndexCol:
17101710
is_data_indexable = True
17111711
_info_fields = ["freq", "tz", "index_name"]
17121712

1713+
name: str
1714+
cname: str
1715+
kind_attr: str
1716+
17131717
def __init__(
17141718
self,
1719+
name: str,
17151720
values=None,
17161721
kind=None,
17171722
typ=None,
1718-
cname=None,
1723+
cname: Optional[str] = None,
17191724
itemsize=None,
1720-
name=None,
17211725
axis=None,
1722-
kind_attr=None,
1726+
kind_attr: Optional[str] = None,
17231727
pos=None,
17241728
freq=None,
17251729
tz=None,
17261730
index_name=None,
17271731
**kwargs,
17281732
):
1733+
1734+
if not isinstance(name, str):
1735+
raise ValueError("`name` must be a str.")
1736+
17291737
self.values = values
17301738
self.kind = kind
17311739
self.typ = typ
17321740
self.itemsize = itemsize
17331741
self.name = name
1734-
self.cname = cname
1735-
self.kind_attr = kind_attr
1742+
self.cname = cname or name
1743+
self.kind_attr = kind_attr or f"{name}_kind"
17361744
self.axis = axis
17371745
self.pos = pos
17381746
self.freq = freq
@@ -1742,19 +1750,14 @@ def __init__(
17421750
self.meta = None
17431751
self.metadata = None
17441752

1745-
if name is not None:
1746-
self.set_name(name, kind_attr)
17471753
if pos is not None:
17481754
self.set_pos(pos)
17491755

1750-
def set_name(self, name, kind_attr=None):
1751-
""" set the name of this indexer """
1752-
self.name = name
1753-
self.kind_attr = kind_attr or "{name}_kind".format(name=name)
1754-
if self.cname is None:
1755-
self.cname = name
1756-
1757-
return self
1756+
# These are ensured as long as the passed arguments match the
1757+
# constructor annotations.
1758+
assert isinstance(self.name, str)
1759+
assert isinstance(self.cname, str)
1760+
assert isinstance(self.kind_attr, str)
17581761

17591762
def set_axis(self, axis: int):
17601763
""" set the axis over which I index """
@@ -1771,7 +1774,6 @@ def set_pos(self, pos: int):
17711774

17721775
def set_table(self, table):
17731776
self.table = table
1774-
return self
17751777

17761778
def __repr__(self) -> str:
17771779
temp = tuple(
@@ -1797,10 +1799,13 @@ def __ne__(self, other) -> bool:
17971799
@property
17981800
def is_indexed(self) -> bool:
17991801
""" return whether I am an indexed column """
1800-
try:
1801-
return getattr(self.table.cols, self.cname).is_indexed
1802-
except AttributeError:
1802+
if not hasattr(self.table, "cols"):
1803+
# e.g. if self.set_table hasn't been called yet, self.table
1804+
# will be None.
18031805
return False
1806+
# GH#29692 mypy doesn't recognize self.table as having a "cols" attribute
1807+
# 'error: "None" has no attribute "cols"'
1808+
return getattr(self.table.cols, self.cname).is_indexed # type: ignore
18041809

18051810
def copy(self):
18061811
new_self = copy.copy(self)
@@ -2508,6 +2513,7 @@ class DataIndexableCol(DataCol):
25082513

25092514
def validate_names(self):
25102515
if not Index(self.values).is_object():
2516+
# TODO: should the message here be more specifically non-str?
25112517
raise ValueError("cannot have non-object label DataIndexableCol")
25122518

25132519
def get_atom_string(self, block, itemsize):
@@ -2842,8 +2848,8 @@ def write_index(self, key, index):
28422848
else:
28432849
setattr(self.attrs, "{key}_variety".format(key=key), "regular")
28442850
converted = _convert_index(
2845-
index, self.encoding, self.errors, self.format_type
2846-
).set_name("index")
2851+
"index", index, self.encoding, self.errors, self.format_type
2852+
)
28472853

28482854
self.write_array(key, converted.values)
28492855

@@ -2893,8 +2899,8 @@ def write_multi_index(self, key, index):
28932899
)
28942900
level_key = "{key}_level{idx}".format(key=key, idx=i)
28952901
conv_level = _convert_index(
2896-
lev, self.encoding, self.errors, self.format_type
2897-
).set_name(level_key)
2902+
level_key, lev, self.encoding, self.errors, self.format_type
2903+
)
28982904
self.write_array(level_key, conv_level.values)
28992905
node = getattr(self.group, level_key)
29002906
node._v_attrs.kind = conv_level.kind
@@ -3436,9 +3442,10 @@ def queryables(self):
34363442

34373443
def index_cols(self):
34383444
""" return a list of my index cols """
3445+
# Note: each `i.cname` below is assured to be a str.
34393446
return [(i.axis, i.cname) for i in self.index_axes]
34403447

3441-
def values_cols(self):
3448+
def values_cols(self) -> List[str]:
34423449
""" return a list of my values cols """
34433450
return [i.cname for i in self.values_axes]
34443451

@@ -3540,6 +3547,8 @@ def indexables(self):
35403547

35413548
self._indexables = []
35423549

3550+
# Note: each of the `name` kwargs below are str, ensured
3551+
# by the definition in index_cols.
35433552
# index columns
35443553
self._indexables.extend(
35453554
[
@@ -3553,13 +3562,16 @@ def indexables(self):
35533562
base_pos = len(self._indexables)
35543563

35553564
def f(i, c):
3565+
assert isinstance(c, str)
35563566
klass = DataCol
35573567
if c in dc:
35583568
klass = DataIndexableCol
35593569
return klass.create_for_block(
35603570
i=i, name=c, pos=base_pos + i, version=self.version
35613571
)
35623572

3573+
# Note: the definition of `values_cols` ensures that each
3574+
# `c` below is a str.
35633575
self._indexables.extend(
35643576
[f(i, c) for i, c in enumerate(self.attrs.values_cols)]
35653577
)
@@ -3797,11 +3809,9 @@ def create_axes(
37973809

37983810
if i in axes:
37993811
name = obj._AXIS_NAMES[i]
3800-
index_axes_map[i] = (
3801-
_convert_index(a, self.encoding, self.errors, self.format_type)
3802-
.set_name(name)
3803-
.set_axis(i)
3804-
)
3812+
index_axes_map[i] = _convert_index(
3813+
name, a, self.encoding, self.errors, self.format_type
3814+
).set_axis(i)
38053815
else:
38063816

38073817
# we might be able to change the axes on the appending data if
@@ -3900,6 +3910,9 @@ def get_blk_items(mgr, blocks):
39003910
if data_columns and len(b_items) == 1 and b_items[0] in data_columns:
39013911
klass = DataIndexableCol
39023912
name = b_items[0]
3913+
if not (name is None or isinstance(name, str)):
3914+
# TODO: should the message here be more specifically non-str?
3915+
raise ValueError("cannot have non-object label DataIndexableCol")
39033916
self.data_columns.append(name)
39043917

39053918
# make sure that we match up the existing columns
@@ -4582,6 +4595,7 @@ def indexables(self):
45824595
self._indexables = [GenericIndexCol(name="index", axis=0)]
45834596

45844597
for i, n in enumerate(d._v_names):
4598+
assert isinstance(n, str)
45854599

45864600
dc = GenericDataIndexableCol(
45874601
name=n, pos=i, values=[n], version=self.version
@@ -4700,12 +4714,15 @@ def _set_tz(values, tz, preserve_UTC: bool = False, coerce: bool = False):
47004714
return values
47014715

47024716

4703-
def _convert_index(index, encoding=None, errors="strict", format_type=None):
4717+
def _convert_index(name: str, index, encoding=None, errors="strict", format_type=None):
4718+
assert isinstance(name, str)
4719+
47044720
index_name = getattr(index, "name", None)
47054721

47064722
if isinstance(index, DatetimeIndex):
47074723
converted = index.asi8
47084724
return IndexCol(
4725+
name,
47094726
converted,
47104727
"datetime64",
47114728
_tables().Int64Col(),
@@ -4716,6 +4733,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47164733
elif isinstance(index, TimedeltaIndex):
47174734
converted = index.asi8
47184735
return IndexCol(
4736+
name,
47194737
converted,
47204738
"timedelta64",
47214739
_tables().Int64Col(),
@@ -4726,6 +4744,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47264744
atom = _tables().Int64Col()
47274745
# avoid to store ndarray of Period objects
47284746
return IndexCol(
4747+
name,
47294748
index._ndarray_values,
47304749
"integer",
47314750
atom,
@@ -4743,6 +4762,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47434762
if inferred_type == "datetime64":
47444763
converted = values.view("i8")
47454764
return IndexCol(
4765+
name,
47464766
converted,
47474767
"datetime64",
47484768
_tables().Int64Col(),
@@ -4753,6 +4773,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47534773
elif inferred_type == "timedelta64":
47544774
converted = values.view("i8")
47554775
return IndexCol(
4776+
name,
47564777
converted,
47574778
"timedelta64",
47584779
_tables().Int64Col(),
@@ -4765,18 +4786,21 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47654786
dtype=np.float64,
47664787
)
47674788
return IndexCol(
4768-
converted, "datetime", _tables().Time64Col(), index_name=index_name
4789+
name, converted, "datetime", _tables().Time64Col(), index_name=index_name
47694790
)
47704791
elif inferred_type == "date":
47714792
converted = np.asarray([v.toordinal() for v in values], dtype=np.int32)
4772-
return IndexCol(converted, "date", _tables().Time32Col(), index_name=index_name)
4793+
return IndexCol(
4794+
name, converted, "date", _tables().Time32Col(), index_name=index_name,
4795+
)
47734796
elif inferred_type == "string":
47744797
# atom = _tables().ObjectAtom()
47754798
# return np.asarray(values, dtype='O'), 'object', atom
47764799

47774800
converted = _convert_string_array(values, encoding, errors)
47784801
itemsize = converted.dtype.itemsize
47794802
return IndexCol(
4803+
name,
47804804
converted,
47814805
"string",
47824806
_tables().StringCol(itemsize),
@@ -4787,7 +4811,11 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47874811
if format_type == "fixed":
47884812
atom = _tables().ObjectAtom()
47894813
return IndexCol(
4790-
np.asarray(values, dtype="O"), "object", atom, index_name=index_name
4814+
name,
4815+
np.asarray(values, dtype="O"),
4816+
"object",
4817+
atom,
4818+
index_name=index_name,
47914819
)
47924820
raise TypeError(
47934821
"[unicode] is not supported as a in index type for [{0}] formats".format(
@@ -4799,17 +4827,25 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
47994827
# take a guess for now, hope the values fit
48004828
atom = _tables().Int64Col()
48014829
return IndexCol(
4802-
np.asarray(values, dtype=np.int64), "integer", atom, index_name=index_name
4830+
name,
4831+
np.asarray(values, dtype=np.int64),
4832+
"integer",
4833+
atom,
4834+
index_name=index_name,
48034835
)
48044836
elif inferred_type == "floating":
48054837
atom = _tables().Float64Col()
48064838
return IndexCol(
4807-
np.asarray(values, dtype=np.float64), "float", atom, index_name=index_name
4839+
name,
4840+
np.asarray(values, dtype=np.float64),
4841+
"float",
4842+
atom,
4843+
index_name=index_name,
48084844
)
48094845
else: # pragma: no cover
48104846
atom = _tables().ObjectAtom()
48114847
return IndexCol(
4812-
np.asarray(values, dtype="O"), "object", atom, index_name=index_name
4848+
name, np.asarray(values, dtype="O"), "object", atom, index_name=index_name,
48134849
)
48144850

48154851

0 commit comments

Comments
 (0)