Skip to content

Commit d1bc773

Browse files
jbrockmendeljreback
authored andcommitted
REF: use new pytables helper methods to de-duplicate _convert_index (#30144)
1 parent 71cdd50 commit d1bc773

File tree

1 file changed

+51
-64
lines changed

1 file changed

+51
-64
lines changed

pandas/io/pytables.py

+51-64
Original file line numberDiff line numberDiff line change
@@ -2256,16 +2256,7 @@ def set_data(self, data: Union[np.ndarray, ABCExtensionArray]):
22562256
assert data is not None
22572257
assert self.dtype is None
22582258

2259-
if is_categorical_dtype(data.dtype):
2260-
data = data.codes
2261-
2262-
# For datetime64tz we need to drop the TZ in tests TODO: why?
2263-
dtype_name = data.dtype.name.split("[")[0]
2264-
2265-
if data.dtype.kind in ["m", "M"]:
2266-
data = np.asarray(data.view("i8"))
2267-
# TODO: we used to reshape for the dt64tz case, but no longer
2268-
# doing that doesnt seem to break anything. why?
2259+
data, dtype_name = _get_data_and_dtype_name(data)
22692260

22702261
self.data = data
22712262
self.dtype = dtype_name
@@ -2318,6 +2309,9 @@ def get_atom_coltype(cls, kind: str) -> Type["Col"]:
23182309
if kind.startswith("uint"):
23192310
k4 = kind[4:]
23202311
col_name = f"UInt{k4}Col"
2312+
elif kind.startswith("period"):
2313+
# we store as integer
2314+
col_name = "Int64Col"
23212315
else:
23222316
kcap = kind.capitalize()
23232317
col_name = f"{kcap}Col"
@@ -4638,37 +4632,21 @@ def _convert_index(name: str, index: Index, encoding=None, errors="strict"):
46384632
assert isinstance(name, str)
46394633

46404634
index_name = index.name
4641-
4642-
if isinstance(index, DatetimeIndex):
4643-
converted = index.asi8
4644-
return IndexCol(
4645-
name,
4646-
converted,
4647-
"datetime64",
4648-
_tables().Int64Col(),
4649-
freq=index.freq,
4650-
tz=index.tz,
4651-
index_name=index_name,
4652-
)
4653-
elif isinstance(index, TimedeltaIndex):
4654-
converted = index.asi8
4635+
converted, dtype_name = _get_data_and_dtype_name(index)
4636+
kind = _dtype_to_kind(dtype_name)
4637+
atom = DataIndexableCol._get_atom(converted)
4638+
4639+
if isinstance(index, Int64Index):
4640+
# Includes Int64Index, RangeIndex, DatetimeIndex, TimedeltaIndex, PeriodIndex,
4641+
# in which case "kind" is "integer", "integer", "datetime64",
4642+
# "timedelta64", and "integer", respectively.
46554643
return IndexCol(
46564644
name,
4657-
converted,
4658-
"timedelta64",
4659-
_tables().Int64Col(),
4660-
freq=index.freq,
4661-
index_name=index_name,
4662-
)
4663-
elif isinstance(index, (Int64Index, PeriodIndex)):
4664-
atom = _tables().Int64Col()
4665-
# avoid to store ndarray of Period objects
4666-
return IndexCol(
4667-
name,
4668-
index._ndarray_values,
4669-
"integer",
4670-
atom,
4645+
values=converted,
4646+
kind=kind,
4647+
typ=atom,
46714648
freq=getattr(index, "freq", None),
4649+
tz=getattr(index, "tz", None),
46724650
index_name=index_name,
46734651
)
46744652

@@ -4687,8 +4665,6 @@ def _convert_index(name: str, index: Index, encoding=None, errors="strict"):
46874665
name, converted, "date", _tables().Time32Col(), index_name=index_name,
46884666
)
46894667
elif inferred_type == "string":
4690-
# atom = _tables().ObjectAtom()
4691-
# return np.asarray(values, dtype='O'), 'object', atom
46924668

46934669
converted = _convert_string_array(values, encoding, errors)
46944670
itemsize = converted.dtype.itemsize
@@ -4700,30 +4676,15 @@ def _convert_index(name: str, index: Index, encoding=None, errors="strict"):
47004676
index_name=index_name,
47014677
)
47024678

4703-
elif inferred_type == "integer":
4704-
# take a guess for now, hope the values fit
4705-
atom = _tables().Int64Col()
4679+
elif inferred_type in ["integer", "floating"]:
47064680
return IndexCol(
4707-
name,
4708-
np.asarray(values, dtype=np.int64),
4709-
"integer",
4710-
atom,
4711-
index_name=index_name,
4712-
)
4713-
elif inferred_type == "floating":
4714-
atom = _tables().Float64Col()
4715-
return IndexCol(
4716-
name,
4717-
np.asarray(values, dtype=np.float64),
4718-
"float",
4719-
atom,
4720-
index_name=index_name,
4681+
name, values=converted, kind=kind, typ=atom, index_name=index_name,
47214682
)
47224683
else:
4684+
assert isinstance(converted, np.ndarray) and converted.dtype == object
4685+
assert kind == "object", kind
47234686
atom = _tables().ObjectAtom()
4724-
return IndexCol(
4725-
name, np.asarray(values, dtype="O"), "object", atom, index_name=index_name,
4726-
)
4687+
return IndexCol(name, converted, kind, atom, index_name=index_name,)
47274688

47284689

47294690
def _unconvert_index(data, kind: str, encoding=None, errors="strict"):
@@ -4950,21 +4911,47 @@ def _dtype_to_kind(dtype_str: str) -> str:
49504911
kind = "complex"
49514912
elif dtype_str.startswith("int") or dtype_str.startswith("uint"):
49524913
kind = "integer"
4953-
elif dtype_str.startswith("date"):
4954-
# in tests this is always "datetime64"
4955-
kind = "datetime"
4914+
elif dtype_str.startswith("datetime64"):
4915+
kind = "datetime64"
49564916
elif dtype_str.startswith("timedelta"):
4957-
kind = "timedelta"
4917+
kind = "timedelta64"
49584918
elif dtype_str.startswith("bool"):
49594919
kind = "bool"
49604920
elif dtype_str.startswith("category"):
49614921
kind = "category"
4922+
elif dtype_str.startswith("period"):
4923+
# We store the `freq` attr so we can restore from integers
4924+
kind = "integer"
4925+
elif dtype_str == "object":
4926+
kind = "object"
49624927
else:
49634928
raise ValueError(f"cannot interpret dtype of [{dtype_str}]")
49644929

49654930
return kind
49664931

49674932

4933+
def _get_data_and_dtype_name(data: Union[np.ndarray, ABCExtensionArray]):
4934+
"""
4935+
Convert the passed data into a storable form and a dtype string.
4936+
"""
4937+
if is_categorical_dtype(data.dtype):
4938+
data = data.codes
4939+
4940+
# For datetime64tz we need to drop the TZ in tests TODO: why?
4941+
dtype_name = data.dtype.name.split("[")[0]
4942+
4943+
if data.dtype.kind in ["m", "M"]:
4944+
data = np.asarray(data.view("i8"))
4945+
# TODO: we used to reshape for the dt64tz case, but no longer
4946+
# doing that doesnt seem to break anything. why?
4947+
4948+
elif isinstance(data, PeriodIndex):
4949+
data = data.asi8
4950+
4951+
data = np.asarray(data)
4952+
return data, dtype_name
4953+
4954+
49684955
class Selection:
49694956
"""
49704957
Carries out a selection operation on a tables.Table object.

0 commit comments

Comments
 (0)