Skip to content

Commit 2877667

Browse files
jbrockmendel authored and proost committed
REF: implement io.pytables.DataCol._get_atom (pandas-dev#30102)
1 parent 4d82e5d commit 2877667

File tree

1 file changed

+59
-22
lines changed

1 file changed

+59
-22
lines changed

pandas/io/pytables.py

+59-22
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,12 @@
3434
from pandas.core.dtypes.common import (
3535
ensure_object,
3636
is_categorical_dtype,
37+
is_complex_dtype,
3738
is_datetime64_dtype,
3839
is_datetime64tz_dtype,
3940
is_extension_array_dtype,
4041
is_list_like,
42+
is_string_dtype,
4143
is_timedelta64_dtype,
4244
)
4345
from pandas.core.dtypes.generic import ABCExtensionArray
@@ -2353,16 +2355,48 @@ def set_atom(self, block, data_converted, use_str: bool):
23532355
# set as a data block
23542356
self.set_atom_data(block)
23552357

2356-
def get_atom_string(self, shape, itemsize):
2358+
@classmethod
2359+
def _get_atom(cls, values: Union[np.ndarray, ABCExtensionArray]) -> "Col":
2360+
"""
2361+
Get an appropriately typed and shaped pytables.Col object for values.
2362+
"""
2363+
2364+
dtype = values.dtype
2365+
itemsize = dtype.itemsize
2366+
2367+
shape = values.shape
2368+
if values.ndim == 1:
2369+
# EA, use block shape pretending it is 2D
2370+
shape = (1, values.size)
2371+
2372+
if is_categorical_dtype(dtype):
2373+
codes = values.codes
2374+
atom = cls.get_atom_data(shape, kind=codes.dtype.name)
2375+
elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype):
2376+
atom = cls.get_atom_datetime64(shape)
2377+
elif is_timedelta64_dtype(dtype):
2378+
atom = cls.get_atom_timedelta64(shape)
2379+
elif is_complex_dtype(dtype):
2380+
atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0])
2381+
2382+
elif is_string_dtype(dtype):
2383+
atom = cls.get_atom_string(shape, itemsize)
2384+
2385+
else:
2386+
atom = cls.get_atom_data(shape, kind=dtype.name)
2387+
2388+
return atom
2389+
2390+
@classmethod
2391+
def get_atom_string(cls, shape, itemsize):
23572392
return _tables().StringCol(itemsize=itemsize, shape=shape[0])
23582393

23592394
def set_atom_string(self, data_converted: np.ndarray):
2360-
itemsize = data_converted.dtype.itemsize
23612395
self.kind = "string"
2362-
self.typ = self.get_atom_string(data_converted.shape, itemsize)
23632396
self.set_data(data_converted)
23642397

2365-
def get_atom_coltype(self, kind: str) -> Type["Col"]:
2398+
@classmethod
2399+
def get_atom_coltype(cls, kind: str) -> Type["Col"]:
23662400
""" return the PyTables column class for this column """
23672401
if kind.startswith("uint"):
23682402
k4 = kind[4:]
@@ -2373,18 +2407,16 @@ def get_atom_coltype(self, kind: str) -> Type["Col"]:
23732407

23742408
return getattr(_tables(), col_name)
23752409

2376-
def get_atom_data(self, shape, kind: str) -> "Col":
2377-
return self.get_atom_coltype(kind=kind)(shape=shape[0])
2410+
@classmethod
2411+
def get_atom_data(cls, shape, kind: str) -> "Col":
2412+
return cls.get_atom_coltype(kind=kind)(shape=shape[0])
23782413

23792414
def set_atom_complex(self, block):
23802415
self.kind = block.dtype.name
2381-
itemsize = int(self.kind.split("complex")[-1]) // 8
2382-
self.typ = _tables().ComplexCol(itemsize=itemsize, shape=block.shape[0])
23832416
self.set_data(block.values)
23842417

23852418
def set_atom_data(self, block):
23862419
self.kind = block.dtype.name
2387-
self.typ = self.get_atom_data(block.shape, kind=block.dtype.name)
23882420
self.set_data(block.values)
23892421

23902422
def set_atom_categorical(self, block):
@@ -2401,7 +2433,6 @@ def set_atom_categorical(self, block):
24012433

24022434
# write the codes; must be in a block shape
24032435
self.ordered = values.ordered
2404-
self.typ = self.get_atom_data(block.shape, kind=codes.dtype.name)
24052436
self.set_data(block.values)
24062437

24072438
# write the categories
@@ -2410,12 +2441,12 @@ def set_atom_categorical(self, block):
24102441
assert self.kind == "integer", self.kind
24112442
assert self.dtype == codes.dtype.name, codes.dtype.name
24122443

2413-
def get_atom_datetime64(self, block):
2414-
return _tables().Int64Col(shape=block.shape[0])
2444+
@classmethod
2445+
def get_atom_datetime64(cls, shape):
2446+
return _tables().Int64Col(shape=shape[0])
24152447

24162448
def set_atom_datetime64(self, block):
24172449
self.kind = "datetime64"
2418-
self.typ = self.get_atom_datetime64(block)
24192450
self.set_data(block.values)
24202451

24212452
def set_atom_datetime64tz(self, block):
@@ -2424,15 +2455,14 @@ def set_atom_datetime64tz(self, block):
24242455
self.tz = _get_tz(block.values.tz)
24252456

24262457
self.kind = "datetime64"
2427-
self.typ = self.get_atom_datetime64(block)
24282458
self.set_data(block.values)
24292459

2430-
def get_atom_timedelta64(self, block):
2431-
return _tables().Int64Col(shape=block.shape[0])
2460+
@classmethod
2461+
def get_atom_timedelta64(cls, shape):
2462+
return _tables().Int64Col(shape=shape[0])
24322463

24332464
def set_atom_timedelta64(self, block):
24342465
self.kind = "timedelta64"
2435-
self.typ = self.get_atom_timedelta64(block)
24362466
self.set_data(block.values)
24372467

24382468
@property
@@ -2558,16 +2588,20 @@ def validate_names(self):
25582588
# TODO: should the message here be more specifically non-str?
25592589
raise ValueError("cannot have non-object label DataIndexableCol")
25602590

2561-
def get_atom_string(self, shape, itemsize):
2591+
@classmethod
2592+
def get_atom_string(cls, shape, itemsize):
25622593
return _tables().StringCol(itemsize=itemsize)
25632594

2564-
def get_atom_data(self, shape, kind: str) -> "Col":
2565-
return self.get_atom_coltype(kind=kind)()
2595+
@classmethod
2596+
def get_atom_data(cls, shape, kind: str) -> "Col":
2597+
return cls.get_atom_coltype(kind=kind)()
25662598

2567-
def get_atom_datetime64(self, block):
2599+
@classmethod
2600+
def get_atom_datetime64(cls, shape):
25682601
return _tables().Int64Col()
25692602

2570-
def get_atom_timedelta64(self, block):
2603+
@classmethod
2604+
def get_atom_timedelta64(cls, shape):
25712605
return _tables().Int64Col()
25722606

25732607

@@ -3922,8 +3956,11 @@ def get_blk_items(mgr, blocks):
39223956
errors=self.errors,
39233957
)
39243958

3959+
typ = klass._get_atom(data_converted)
3960+
39253961
col = klass.create_for_block(i=i, name=new_name, version=self.version)
39263962
col.values = list(b_items)
3963+
col.typ = typ
39273964
col.set_atom(block=b, data_converted=data_converted, use_str=use_str)
39283965
col.update_info(self.info)
39293966
col.set_pos(j)

0 commit comments

Comments
 (0)