diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 66b8089537e8d..347d60f8a7354 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -34,10 +34,12 @@ from pandas.core.dtypes.common import ( ensure_object, is_categorical_dtype, + is_complex_dtype, is_datetime64_dtype, is_datetime64tz_dtype, is_extension_array_dtype, is_list_like, + is_string_dtype, is_timedelta64_dtype, ) from pandas.core.dtypes.generic import ABCExtensionArray @@ -2353,16 +2355,48 @@ def set_atom(self, block, data_converted, use_str: bool): # set as a data block self.set_atom_data(block) - def get_atom_string(self, shape, itemsize): + @classmethod + def _get_atom(cls, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": + """ + Get an appropriately typed and shaped pytables.Col object for values. + """ + + dtype = values.dtype + itemsize = dtype.itemsize + + shape = values.shape + if values.ndim == 1: + # EA, use block shape pretending it is 2D + shape = (1, values.size) + + if is_categorical_dtype(dtype): + codes = values.codes + atom = cls.get_atom_data(shape, kind=codes.dtype.name) + elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype): + atom = cls.get_atom_datetime64(shape) + elif is_timedelta64_dtype(dtype): + atom = cls.get_atom_timedelta64(shape) + elif is_complex_dtype(dtype): + atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) + + elif is_string_dtype(dtype): + atom = cls.get_atom_string(shape, itemsize) + + else: + atom = cls.get_atom_data(shape, kind=dtype.name) + + return atom + + @classmethod + def get_atom_string(cls, shape, itemsize): return _tables().StringCol(itemsize=itemsize, shape=shape[0]) def set_atom_string(self, data_converted: np.ndarray): - itemsize = data_converted.dtype.itemsize self.kind = "string" - self.typ = self.get_atom_string(data_converted.shape, itemsize) self.set_data(data_converted) - def get_atom_coltype(self, kind: str) -> Type["Col"]: + @classmethod + def get_atom_coltype(cls, kind: str) -> Type["Col"]: """ return the PyTables column class for this column """ if kind.startswith("uint"): k4 = kind[4:] @@ -2373,18 +2407,16 @@ def get_atom_coltype(self, kind: str) -> Type["Col"]: return getattr(_tables(), col_name) - def get_atom_data(self, shape, kind: str) -> "Col": - return self.get_atom_coltype(kind=kind)(shape=shape[0]) + @classmethod + def get_atom_data(cls, shape, kind: str) -> "Col": + return cls.get_atom_coltype(kind=kind)(shape=shape[0]) def set_atom_complex(self, block): self.kind = block.dtype.name - itemsize = int(self.kind.split("complex")[-1]) // 8 - self.typ = _tables().ComplexCol(itemsize=itemsize, shape=block.shape[0]) self.set_data(block.values) def set_atom_data(self, block): self.kind = block.dtype.name - self.typ = self.get_atom_data(block.shape, kind=block.dtype.name) self.set_data(block.values) def set_atom_categorical(self, block): @@ -2401,7 +2433,6 @@ def set_atom_categorical(self, block): # write the codes; must be in a block shape self.ordered = values.ordered - self.typ = self.get_atom_data(block.shape, kind=codes.dtype.name) self.set_data(block.values) # write the categories @@ -2410,12 +2441,12 @@ def set_atom_categorical(self, block): assert self.kind == "integer", self.kind assert self.dtype == codes.dtype.name, codes.dtype.name - def get_atom_datetime64(self, block): - return _tables().Int64Col(shape=block.shape[0]) + @classmethod + def get_atom_datetime64(cls, shape): + return _tables().Int64Col(shape=shape[0]) def set_atom_datetime64(self, block): self.kind = "datetime64" - self.typ = self.get_atom_datetime64(block) self.set_data(block.values) def set_atom_datetime64tz(self, block): @@ -2424,15 +2455,14 @@ def set_atom_datetime64tz(self, block): self.tz = _get_tz(block.values.tz) self.kind = "datetime64" - self.typ = self.get_atom_datetime64(block) self.set_data(block.values) - def get_atom_timedelta64(self, block): - return _tables().Int64Col(shape=block.shape[0]) + @classmethod + def get_atom_timedelta64(cls, shape): + return _tables().Int64Col(shape=shape[0]) def set_atom_timedelta64(self, block): self.kind = "timedelta64" - self.typ = self.get_atom_timedelta64(block) self.set_data(block.values) @property @@ -2558,16 +2588,20 @@ def validate_names(self): # TODO: should the message here be more specifically non-str? raise ValueError("cannot have non-object label DataIndexableCol") - def get_atom_string(self, shape, itemsize): + @classmethod + def get_atom_string(cls, shape, itemsize): return _tables().StringCol(itemsize=itemsize) - def get_atom_data(self, shape, kind: str) -> "Col": - return self.get_atom_coltype(kind=kind)() + @classmethod + def get_atom_data(cls, shape, kind: str) -> "Col": + return cls.get_atom_coltype(kind=kind)() - def get_atom_datetime64(self, block): + @classmethod + def get_atom_datetime64(cls, shape): return _tables().Int64Col() - def get_atom_timedelta64(self, block): + @classmethod + def get_atom_timedelta64(cls, shape): return _tables().Int64Col() @@ -3922,8 +3956,11 @@ def get_blk_items(mgr, blocks): errors=self.errors, ) + typ = klass._get_atom(data_converted) + col = klass.create_for_block(i=i, name=new_name, version=self.version) col.values = list(b_items) + col.typ = typ col.set_atom(block=b, data_converted=data_converted, use_str=use_str) col.update_info(self.info) col.set_pos(j)