Skip to content

Commit f0e8cce

Browse files
jbrockmendel authored and proost committed
CLN: annotate and de-nest write_array (pandas-dev#30012)
1 parent 93b9158 commit f0e8cce

File tree

1 file changed

+44
-46
lines changed

1 file changed

+44
-46
lines changed

pandas/io/pytables.py

+44-46
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
is_list_like,
4141
is_timedelta64_dtype,
4242
)
43+
from pandas.core.dtypes.generic import ABCExtensionArray
4344
from pandas.core.dtypes.missing import array_equivalent
4445

4546
from pandas import (
@@ -54,7 +55,7 @@
5455
concat,
5556
isna,
5657
)
57-
from pandas._typing import FrameOrSeries
58+
from pandas._typing import ArrayLike, FrameOrSeries
5859
from pandas.core.arrays.categorical import Categorical
5960
import pandas.core.common as com
6061
from pandas.core.computation.pytables import PyTablesExpr, maybe_expression
@@ -2960,7 +2961,7 @@ def read_index_node(
29602961
data = node[start:stop]
29612962
# If the index was an empty array write_array_empty() will
29622963
# have written a sentinel. Here we replace it with the original.
2963-
if "shape" in node._v_attrs and self._is_empty_array(node._v_attrs.shape):
2964+
if "shape" in node._v_attrs and np.prod(node._v_attrs.shape) == 0:
29642965
data = np.empty(node._v_attrs.shape, dtype=node._v_attrs.value_type,)
29652966
kind = _ensure_decoded(node._v_attrs.kind)
29662967
name = None
@@ -3006,25 +3007,27 @@ def read_index_node(
30063007

30073008
return index
30083009

3009-
def write_array_empty(self, key: str, value):
3010+
def write_array_empty(self, key: str, value: ArrayLike):
30103011
""" write a 0-len array """
30113012

30123013
# ugly hack for length 0 axes
30133014
arr = np.empty((1,) * value.ndim)
30143015
self._handle.create_array(self.group, key, arr)
3015-
getattr(self.group, key)._v_attrs.value_type = str(value.dtype)
3016-
getattr(self.group, key)._v_attrs.shape = value.shape
3016+
node = getattr(self.group, key)
3017+
node._v_attrs.value_type = str(value.dtype)
3018+
node._v_attrs.shape = value.shape
30173019

3018-
def _is_empty_array(self, shape) -> bool:
3019-
"""Returns true if any axis is zero length."""
3020-
return any(x == 0 for x in shape)
3020+
def write_array(self, key: str, value: ArrayLike, items: Optional[Index] = None):
3021+
# TODO: we only have one test that gets here, the only EA
3022+
# that gets passed is DatetimeArray, and we never have
3023+
# both self._filters and EA
3024+
assert isinstance(value, (np.ndarray, ABCExtensionArray)), type(value)
30213025

3022-
def write_array(self, key: str, value, items=None):
30233026
if key in self.group:
30243027
self._handle.remove_node(self.group, key)
30253028

30263029
# Transform needed to interface with pytables row/col notation
3027-
empty_array = self._is_empty_array(value.shape)
3030+
empty_array = value.size == 0
30283031
transposed = False
30293032

30303033
if is_categorical_dtype(value):
@@ -3039,29 +3042,29 @@ def write_array(self, key: str, value, items=None):
30393042
value = value.T
30403043
transposed = True
30413044

3045+
atom = None
30423046
if self._filters is not None:
3043-
atom = None
30443047
try:
30453048
# get the atom for this datatype
30463049
atom = _tables().Atom.from_dtype(value.dtype)
30473050
except ValueError:
30483051
pass
30493052

3050-
if atom is not None:
3051-
# create an empty chunked array and fill it from value
3052-
if not empty_array:
3053-
ca = self._handle.create_carray(
3054-
self.group, key, atom, value.shape, filters=self._filters
3055-
)
3056-
ca[:] = value
3057-
getattr(self.group, key)._v_attrs.transposed = transposed
3053+
if atom is not None:
3054+
# We only get here if self._filters is non-None and
3055+
# the Atom.from_dtype call succeeded
30583056

3059-
else:
3060-
self.write_array_empty(key, value)
3057+
# create an empty chunked array and fill it from value
3058+
if not empty_array:
3059+
ca = self._handle.create_carray(
3060+
self.group, key, atom, value.shape, filters=self._filters
3061+
)
3062+
ca[:] = value
30613063

3062-
return
3064+
else:
3065+
self.write_array_empty(key, value)
30633066

3064-
if value.dtype.type == np.object_:
3067+
elif value.dtype.type == np.object_:
30653068

30663069
# infer the type, warn if we have a non-string type here (for
30673070
# performance)
@@ -3071,35 +3074,30 @@ def write_array(self, key: str, value, items=None):
30713074
elif inferred_type == "string":
30723075
pass
30733076
else:
3074-
try:
3075-
items = list(items)
3076-
except TypeError:
3077-
pass
30783077
ws = performance_doc % (inferred_type, key, items)
30793078
warnings.warn(ws, PerformanceWarning, stacklevel=7)
30803079

30813080
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom())
30823081
vlarr.append(value)
3082+
3083+
elif empty_array:
3084+
self.write_array_empty(key, value)
3085+
elif is_datetime64_dtype(value.dtype):
3086+
self._handle.create_array(self.group, key, value.view("i8"))
3087+
getattr(self.group, key)._v_attrs.value_type = "datetime64"
3088+
elif is_datetime64tz_dtype(value.dtype):
3089+
# store as UTC
3090+
# with a zone
3091+
self._handle.create_array(self.group, key, value.asi8)
3092+
3093+
node = getattr(self.group, key)
3094+
node._v_attrs.tz = _get_tz(value.tz)
3095+
node._v_attrs.value_type = "datetime64"
3096+
elif is_timedelta64_dtype(value.dtype):
3097+
self._handle.create_array(self.group, key, value.view("i8"))
3098+
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
30833099
else:
3084-
if empty_array:
3085-
self.write_array_empty(key, value)
3086-
else:
3087-
if is_datetime64_dtype(value.dtype):
3088-
self._handle.create_array(self.group, key, value.view("i8"))
3089-
getattr(self.group, key)._v_attrs.value_type = "datetime64"
3090-
elif is_datetime64tz_dtype(value.dtype):
3091-
# store as UTC
3092-
# with a zone
3093-
self._handle.create_array(self.group, key, value.asi8)
3094-
3095-
node = getattr(self.group, key)
3096-
node._v_attrs.tz = _get_tz(value.tz)
3097-
node._v_attrs.value_type = "datetime64"
3098-
elif is_timedelta64_dtype(value.dtype):
3099-
self._handle.create_array(self.group, key, value.view("i8"))
3100-
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
3101-
else:
3102-
self._handle.create_array(self.group, key, value)
3100+
self._handle.create_array(self.group, key, value)
31033101

31043102
getattr(self.group, key)._v_attrs.transposed = transposed
31053103

0 commit comments

Comments
 (0)