Skip to content

Commit cf5a14a

Browse files
jbrockmendelproost
authored andcommitted
REF: make pytables get_atom_data non-stateful (pandas-dev#30074)
1 parent e87275f commit cf5a14a

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pandas/io/pytables.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from pandas.io.formats.printing import adjoin, pprint_thing
6666

6767
if TYPE_CHECKING:
68-
from tables import File, Node # noqa:F401
68+
from tables import File, Node, Col # noqa:F401
6969

7070

7171
# versioning attribute
@@ -1092,6 +1092,9 @@ def remove(self, key: str, where=None, start=None, stop=None):
10921092
except KeyError:
10931093
# the key is not a valid store, re-raising KeyError
10941094
raise
1095+
except AssertionError:
1096+
# surface any assertion errors for e.g. debugging
1097+
raise
10951098
except Exception:
10961099
# In tests we get here with ClosedFileError, TypeError, and
10971100
# _table_mod.NoSuchNodeError. TODO: Catch only these?
@@ -1519,6 +1522,9 @@ def info(self) -> str:
15191522
if s is not None:
15201523
keys.append(pprint_thing(s.pathname or k))
15211524
values.append(pprint_thing(s or "invalid_HDFStore node"))
1525+
except AssertionError:
1526+
# surface any assertion errors for e.g. debugging
1527+
raise
15221528
except Exception as detail:
15231529
keys.append(k)
15241530
dstr = pprint_thing(detail)
@@ -1680,7 +1686,7 @@ def _write_to_group(
16801686
self._handle.remove_node(group, recursive=True)
16811687
group = None
16821688

1683-
# we don't want to store a table node at all if are object is 0-len
1689+
# we don't want to store a table node at all if our object is 0-len
16841690
# as there are not dtypes
16851691
if getattr(value, "empty", None) and (format == "table" or append):
16861692
return
@@ -2356,11 +2362,9 @@ def set_atom_string(self, data_converted: np.ndarray):
23562362
self.typ = self.get_atom_string(data_converted.shape, itemsize)
23572363
self.set_data(data_converted)
23582364

2359-
def get_atom_coltype(self, kind=None):
2365+
def get_atom_coltype(self, kind: str) -> Type["Col"]:
23602366
""" return the PyTables column class for this column """
2361-
if kind is None:
2362-
kind = self.kind
2363-
if self.kind.startswith("uint"):
2367+
if kind.startswith("uint"):
23642368
k4 = kind[4:]
23652369
col_name = f"UInt{k4}Col"
23662370
else:
@@ -2369,8 +2373,8 @@ def get_atom_coltype(self, kind=None):
23692373

23702374
return getattr(_tables(), col_name)
23712375

2372-
def get_atom_data(self, block, kind=None):
2373-
return self.get_atom_coltype(kind=kind)(shape=block.shape[0])
2376+
def get_atom_data(self, shape, kind: str) -> "Col":
2377+
return self.get_atom_coltype(kind=kind)(shape=shape[0])
23742378

23752379
def set_atom_complex(self, block):
23762380
self.kind = block.dtype.name
@@ -2380,7 +2384,7 @@ def set_atom_complex(self, block):
23802384

23812385
def set_atom_data(self, block):
23822386
self.kind = block.dtype.name
2383-
self.typ = self.get_atom_data(block)
2387+
self.typ = self.get_atom_data(block.shape, kind=block.dtype.name)
23842388
self.set_data(block.values)
23852389

23862390
def set_atom_categorical(self, block):
@@ -2389,19 +2393,22 @@ def set_atom_categorical(self, block):
23892393

23902394
values = block.values
23912395
codes = values.codes
2392-
self.kind = "integer"
2393-
self.dtype = codes.dtype.name
2396+
23942397
if values.ndim > 1:
23952398
raise NotImplementedError("only support 1-d categoricals")
23962399

2400+
assert codes.dtype.name.startswith("int"), codes.dtype.name
2401+
23972402
# write the codes; must be in a block shape
23982403
self.ordered = values.ordered
2399-
self.typ = self.get_atom_data(block, kind=codes.dtype.name)
2404+
self.typ = self.get_atom_data(block.shape, kind=codes.dtype.name)
24002405
self.set_data(block.values)
24012406

24022407
# write the categories
24032408
self.meta = "category"
24042409
self.metadata = np.array(block.values.categories, copy=False).ravel()
2410+
assert self.kind == "integer", self.kind
2411+
assert self.dtype == codes.dtype.name, codes.dtype.name
24052412

24062413
def get_atom_datetime64(self, block):
24072414
return _tables().Int64Col(shape=block.shape[0])
@@ -2554,7 +2561,7 @@ def validate_names(self):
25542561
def get_atom_string(self, shape, itemsize):
25552562
return _tables().StringCol(itemsize=itemsize)
25562563

2557-
def get_atom_data(self, block, kind=None):
2564+
def get_atom_data(self, shape, kind: str) -> "Col":
25582565
return self.get_atom_coltype(kind=kind)()
25592566

25602567
def get_atom_datetime64(self, block):

0 commit comments

Comments
 (0)