From 97a2c8451c53e8c2147e97120485e321c2c088aa Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 5 Dec 2019 17:21:41 -0800 Subject: [PATCH 1/4] REF: implement io.pytables.DataCol._get_atom --- pandas/io/pytables.py | 59 ++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 66b8089537e8d..80723e8a08cac 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -2353,13 +2353,43 @@ def set_atom(self, block, data_converted, use_str: bool): # set as a data block self.set_atom_data(block) + def _get_atom(self, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": + """ + Get an appropriately typed and shaped pytables.Col object for values. + """ + + dtype = values.dtype + itemsize = dtype.itemsize + + shape = values.shape + if values.ndim == 1: + # EA, use block shape pretending it is 2D + shape = (1, values.size) + + if is_categorical_dtype(dtype): + codes = values.codes + atom = self.get_atom_data(shape, kind=codes.dtype.name) + elif dtype.kind == "M": + atom = self.get_atom_datetime64(shape) + elif dtype.kind == "m": + atom = self.get_atom_timedelta64(shape) + elif dtype.kind == "c": + atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) + + elif dtype.kind == "S": + atom = self.get_atom_string(shape, itemsize) + + else: + atom = self.get_atom_data(shape, kind=dtype.name) + + return atom + def get_atom_string(self, shape, itemsize): return _tables().StringCol(itemsize=itemsize, shape=shape[0]) def set_atom_string(self, data_converted: np.ndarray): - itemsize = data_converted.dtype.itemsize self.kind = "string" - self.typ = self.get_atom_string(data_converted.shape, itemsize) + self.typ = self._get_atom(data_converted) self.set_data(data_converted) def get_atom_coltype(self, kind: str) -> Type["Col"]: @@ -2378,13 +2408,12 @@ def get_atom_data(self, shape, kind: str) -> "Col": def set_atom_complex(self, block): self.kind = block.dtype.name - itemsize = int(self.kind.split("complex")[-1]) // 8 - self.typ = _tables().ComplexCol(itemsize=itemsize, shape=block.shape[0]) + self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_data(self, block): self.kind = block.dtype.name - self.typ = self.get_atom_data(block.shape, kind=block.dtype.name) + self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_categorical(self, block): @@ -2401,7 +2430,7 @@ def set_atom_categorical(self, block): # write the codes; must be in a block shape self.ordered = values.ordered - self.typ = self.get_atom_data(block.shape, kind=codes.dtype.name) + self.typ = self._get_atom(block.values) self.set_data(block.values) # write the categories @@ -2410,12 +2439,12 @@ def set_atom_categorical(self, block): assert self.kind == "integer", self.kind assert self.dtype == codes.dtype.name, codes.dtype.name - def get_atom_datetime64(self, block): - return _tables().Int64Col(shape=block.shape[0]) + def get_atom_datetime64(self, shape): + return _tables().Int64Col(shape=shape[0]) def set_atom_datetime64(self, block): self.kind = "datetime64" - self.typ = self.get_atom_datetime64(block) + self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_datetime64tz(self, block): @@ -2424,15 +2453,15 @@ def set_atom_datetime64tz(self, block): self.tz = _get_tz(block.values.tz) self.kind = "datetime64" - self.typ = self.get_atom_datetime64(block) + self.typ = self._get_atom(block.values) self.set_data(block.values) - def get_atom_timedelta64(self, block): - return _tables().Int64Col(shape=block.shape[0]) + def get_atom_timedelta64(self, shape): + return _tables().Int64Col(shape=shape[0]) def set_atom_timedelta64(self, block): self.kind = "timedelta64" - self.typ = self.get_atom_timedelta64(block) + self.typ = self._get_atom(block.values) self.set_data(block.values) @property @@ -2564,10 +2593,10 @@ def get_atom_string(self, shape, itemsize): def get_atom_data(self, shape, kind: str) -> "Col": return self.get_atom_coltype(kind=kind)() - def get_atom_datetime64(self, block): + def get_atom_datetime64(self, shape): return _tables().Int64Col() - def get_atom_timedelta64(self, block): + def get_atom_timedelta64(self, shape): return _tables().Int64Col() From 4eda7bd47fa52ad50558dd53d2dc5bf9f5fcefcf Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 5 Dec 2019 17:25:11 -0800 Subject: [PATCH 2/4] REF: collect _get_atom calls in one place --- pandas/io/pytables.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 80723e8a08cac..fe13562623664 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -2389,7 +2389,6 @@ def get_atom_string(self, shape, itemsize): def set_atom_string(self, data_converted: np.ndarray): self.kind = "string" - self.typ = self._get_atom(data_converted) self.set_data(data_converted) def get_atom_coltype(self, kind: str) -> Type["Col"]: @@ -2408,12 +2407,10 @@ def get_atom_data(self, shape, kind: str) -> "Col": def set_atom_complex(self, block): self.kind = block.dtype.name - self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_data(self, block): self.kind = block.dtype.name - self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_categorical(self, block): @@ -2430,7 +2427,6 @@ def set_atom_categorical(self, block): # write the codes; must be in a block shape self.ordered = values.ordered - self.typ = self._get_atom(block.values) self.set_data(block.values) # write the categories @@ -2444,7 +2440,6 @@ def get_atom_datetime64(self, shape): def set_atom_datetime64(self, block): self.kind = "datetime64" - self.typ = self._get_atom(block.values) self.set_data(block.values) def set_atom_datetime64tz(self, block): @@ -2453,7 +2448,6 @@ def set_atom_datetime64tz(self, block): self.tz = _get_tz(block.values.tz) self.kind = "datetime64" - self.typ = self._get_atom(block.values) self.set_data(block.values) def get_atom_timedelta64(self, shape): @@ -2461,7 +2455,6 @@ def get_atom_timedelta64(self, shape): def set_atom_timedelta64(self, block): self.kind = "timedelta64" - self.typ = self._get_atom(block.values) self.set_data(block.values) @property @@ -3953,6 +3946,7 @@ def get_blk_items(mgr, blocks): col = klass.create_for_block(i=i, name=new_name, version=self.version) col.values = list(b_items) + col.typ = col._get_atom(data_converted) col.set_atom(block=b, data_converted=data_converted, use_str=use_str) col.update_info(self.info) col.set_pos(j) From 06c880c9237ba26162bb9e0e099bdc187b29e5fe Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 5 Dec 2019 17:36:49 -0800 Subject: [PATCH 3/4] REF: make the relevant methods into classmethods --- pandas/io/pytables.py | 48 +++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index fe13562623664..7cef9a17d3de5 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -2353,7 +2353,8 @@ def set_atom(self, block, data_converted, use_str: bool): # set as a data block self.set_atom_data(block) - def _get_atom(self, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": + @classmethod + def _get_atom(cls, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": """ Get an appropriately typed and shaped pytables.Col object for values. """ @@ -2368,30 +2369,32 @@ def _get_atom(self, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": if is_categorical_dtype(dtype): codes = values.codes - atom = self.get_atom_data(shape, kind=codes.dtype.name) + atom = cls.get_atom_data(shape, kind=codes.dtype.name) elif dtype.kind == "M": - atom = self.get_atom_datetime64(shape) + atom = cls.get_atom_datetime64(shape) elif dtype.kind == "m": - atom = self.get_atom_timedelta64(shape) + atom = cls.get_atom_timedelta64(shape) elif dtype.kind == "c": atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) elif dtype.kind == "S": - atom = self.get_atom_string(shape, itemsize) + atom = cls.get_atom_string(shape, itemsize) else: - atom = self.get_atom_data(shape, kind=dtype.name) + atom = cls.get_atom_data(shape, kind=dtype.name) return atom - def get_atom_string(self, shape, itemsize): + @classmethod + def get_atom_string(cls, shape, itemsize): return _tables().StringCol(itemsize=itemsize, shape=shape[0]) def set_atom_string(self, data_converted: np.ndarray): self.kind = "string" self.set_data(data_converted) - def get_atom_coltype(self, kind: str) -> Type["Col"]: + @classmethod + def get_atom_coltype(cls, kind: str) -> Type["Col"]: """ return the PyTables column class for this column """ if kind.startswith("uint"): k4 = kind[4:] @@ -2402,8 +2405,9 @@ def get_atom_coltype(self, kind: str) -> Type["Col"]: return getattr(_tables(), col_name) - def get_atom_data(self, shape, kind: str) -> "Col": - return self.get_atom_coltype(kind=kind)(shape=shape[0]) + @classmethod + def get_atom_data(cls, shape, kind: str) -> "Col": + return cls.get_atom_coltype(kind=kind)(shape=shape[0]) def set_atom_complex(self, block): self.kind = block.dtype.name @@ -2435,7 +2439,8 @@ def set_atom_categorical(self, block): assert self.kind == "integer", self.kind assert self.dtype == codes.dtype.name, codes.dtype.name - def get_atom_datetime64(self, shape): + @classmethod + def get_atom_datetime64(cls, shape): return _tables().Int64Col(shape=shape[0]) def set_atom_datetime64(self, block): @@ -2450,7 +2455,8 @@ def set_atom_datetime64tz(self, block): self.kind = "datetime64" self.set_data(block.values) - def get_atom_timedelta64(self, shape): + @classmethod + def get_atom_timedelta64(cls, shape): return _tables().Int64Col(shape=shape[0]) def set_atom_timedelta64(self, block): @@ -2580,16 +2586,20 @@ def validate_names(self): # TODO: should the message here be more specifically non-str? raise ValueError("cannot have non-object label DataIndexableCol") - def get_atom_string(self, shape, itemsize): + @classmethod + def get_atom_string(cls, shape, itemsize): return _tables().StringCol(itemsize=itemsize) - def get_atom_data(self, shape, kind: str) -> "Col": - return self.get_atom_coltype(kind=kind)() + @classmethod + def get_atom_data(cls, shape, kind: str) -> "Col": + return cls.get_atom_coltype(kind=kind)() - def get_atom_datetime64(self, shape): + @classmethod + def get_atom_datetime64(cls, shape): return _tables().Int64Col() - def get_atom_timedelta64(self, shape): + @classmethod + def get_atom_timedelta64(cls, shape): return _tables().Int64Col() @@ -3944,9 +3954,11 @@ def get_blk_items(mgr, blocks): errors=self.errors, ) + typ = klass._get_atom(data_converted) + col = klass.create_for_block(i=i, name=new_name, version=self.version) col.values = list(b_items) - col.typ = col._get_atom(data_converted) + col.typ = typ col.set_atom(block=b, data_converted=data_converted, use_str=use_str) col.update_info(self.info) col.set_pos(j) From 8b0f79a8e6d74c4ef298642fa597741008c96685 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 5 Dec 2019 18:33:36 -0800 Subject: [PATCH 4/4] use is_foo_dtype checks --- pandas/io/pytables.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 7cef9a17d3de5..347d60f8a7354 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -34,10 +34,12 @@ from pandas.core.dtypes.common import ( ensure_object, is_categorical_dtype, + is_complex_dtype, is_datetime64_dtype, is_datetime64tz_dtype, is_extension_array_dtype, is_list_like, + is_string_dtype, is_timedelta64_dtype, ) from pandas.core.dtypes.generic import ABCExtensionArray @@ -2370,14 +2372,14 @@ def _get_atom(cls, values: Union[np.ndarray, ABCExtensionArray]) -> "Col": if is_categorical_dtype(dtype): codes = values.codes atom = cls.get_atom_data(shape, kind=codes.dtype.name) - elif dtype.kind == "M": + elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype): atom = cls.get_atom_datetime64(shape) - elif dtype.kind == "m": + elif is_timedelta64_dtype(dtype): atom = cls.get_atom_timedelta64(shape) - elif dtype.kind == "c": + elif is_complex_dtype(dtype): atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) - elif dtype.kind == "S": + elif is_string_dtype(dtype): atom = cls.get_atom_string(shape, itemsize) else: