Skip to content

Commit 3e98b72

Browse files
committed
ENH: enable storage of hierarchical index objects in HDFStore. addresses GH pandas-dev#128
1 parent 95fb29b commit 3e98b72

File tree

5 files changed: +117 -30 lines changed

pandas/core/daterange.py

+9
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,15 @@ def __setstate__(self, aug_state):
114114
self.tzinfo = tzinfo
115115
Index.__setstate__(self, *index_state)
116116

117+
def equals(self, other):
    """
    Determine whether ``other`` is an Index holding the same elements.

    Identity short-circuits to True and non-Index objects are never equal.
    Otherwise the comparison is delegated to ``Index.equals`` using a plain
    ``Index`` view of this object.
    """
    if self is other:
        return True

    if isinstance(other, Index):
        # Compare through a base-Index view of self rather than self directly.
        return Index.equals(self.view(Index), other)

    return False
125+
117126
def is_all_dates(self):
118127
return True
119128

pandas/core/index.py

+3
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,9 @@ def equals(self, other):
156156
if not isinstance(other, Index):
157157
return False
158158

159+
if type(other) != Index:
160+
return other.equals(self)
161+
159162
return np.array_equal(self, other)
160163

161164
def asof(self, label):

pandas/io/pytables.py

+90-29
Original file line number | Diff line number | Diff line change
@@ -359,13 +359,13 @@ def _read_block_manager(self, group):
359359

360360
axes = []
361361
for i in xrange(ndim):
362-
ax = _read_index(group, 'axis%d' % i)
362+
ax = self._read_index(group, 'axis%d' % i)
363363
axes.append(ax)
364364

365365
items = axes[0]
366366
blocks = []
367367
for i in range(group._v_attrs.nblocks):
368-
blk_items = _read_index(group, 'block%d_items' % i)
368+
blk_items = self._read_index(group, 'block%d_items' % i)
369369
values = _read_array(group, 'block%d_values' % i)
370370
blk = make_block(values, blk_items, items)
371371
blocks.append(blk)
@@ -410,9 +410,9 @@ def _write_long(self, group, panel, append=False):
410410
def _read_long(self, group, where=None):
411411
from pandas.core.index import MultiIndex
412412

413-
items = _read_index(group, 'items')
414-
major_axis = _read_index(group, 'major_axis')
415-
minor_axis = _read_index(group, 'minor_axis')
413+
items = self._read_index(group, 'items')
414+
major_axis = self._read_index(group, 'major_axis')
415+
minor_axis = self._read_index(group, 'minor_axis')
416416
major_labels = _read_array(group, 'major_labels')
417417
minor_labels = _read_array(group, 'minor_labels')
418418
values = _read_array(group, 'values')
@@ -421,12 +421,80 @@ def _read_long(self, group, where=None):
421421
labels=[major_labels, minor_labels])
422422
return LongPanel(values, index=index, columns=items)
423423

424-
def _write_index(self, group, key, value):
425-
# don't care about type here
426-
converted, kind, _ = _convert_index(value)
427-
self._write_array(group, key, converted)
428-
node = getattr(group, key)
429-
node._v_attrs.kind = kind
424+
def _write_index(self, group, key, index):
    """
    Store ``index`` under ``key`` in ``group`` and record its variety.

    The attribute ``'<key>_variety'`` on ``group._v_attrs`` is set to
    ``'multi'`` for a MultiIndex (which is then handed to
    ``_write_multi_index``) and to ``'regular'`` for anything else. A regular
    index is converted with ``_convert_index``, written as an array, and the
    resulting node is tagged with the converted ``kind``.
    """
    variety_attr = '%s_variety' % key

    if isinstance(index, MultiIndex):
        setattr(group._v_attrs, variety_attr, 'multi')
        self._write_multi_index(group, key, index)
        return

    setattr(group._v_attrs, variety_attr, 'regular')
    # The third value returned by _convert_index is not needed here.
    converted, kind, _ = _convert_index(index)
    self._write_array(group, key, converted)
    getattr(group, key)._v_attrs.kind = kind
434+
435+
def _read_index(self, group, key):
436+
try:
437+
variety = getattr(group._v_attrs, '%s_variety' % key)
438+
except Exception:
439+
variety = 'regular'
440+
441+
if variety == 'multi':
442+
return self._read_multi_index(group, key)
443+
elif variety == 'regular':
444+
_, index = self._read_index_node(getattr(group, key))
445+
return index
446+
else:
447+
raise Exception('unrecognized index variety')
448+
449+
def _write_multi_index(self, group, key, index):
    """
    Write every level and label array of a hierarchical ``index``.

    The number of levels is stored on ``group._v_attrs`` as
    ``'<key>_nlevels'``. For level ``i`` the converted level values go to the
    array ``'<key>_level<i>'`` (tagged with ``kind`` and ``name`` attributes)
    and the labels go to the array ``'<key>_label<i>'``.
    """
    setattr(group._v_attrs, '%s_nlevels' % key, index.nlevels)

    per_level = zip(index.levels, index.labels, index.names)
    for pos, (level, level_labels, level_name) in enumerate(per_level):
        level_key = '%s_level%d' % (key, pos)

        # Level values, with the metadata needed to rebuild them.
        conv_level, kind, _ = _convert_index(level)
        self._write_array(group, level_key, conv_level)
        level_node = getattr(group, level_key)
        level_node._v_attrs.kind = kind
        level_node._v_attrs.name = level_name

        # NOTE(review): the name is stored a second time under
        # '<key>_name<i>'; confirm whether anything reads this attribute
        # before dropping it.
        setattr(level_node._v_attrs, '%s_name%d' % (key, pos), level_name)

        # Label array for this level.
        label_key = '%s_label%d' % (key, pos)
        self._write_array(group, label_key, level_labels)
469+
470+
def _read_multi_index(self, group, key):
    """
    Rebuild a MultiIndex from the per-level arrays stored under ``key``.

    The level count comes from the ``'<key>_nlevels'`` attribute of
    ``group._v_attrs``. For each level the values and name are read from
    ``'<key>_level<i>'`` via ``_read_index_node`` and the labels are read in
    full from ``'<key>_label<i>'``.
    """
    nlevels = getattr(group._v_attrs, '%s_nlevels' % key)

    levels, labels, names = [], [], []
    for pos in range(nlevels):
        level_node = getattr(group, '%s_level%d' % (key, pos))
        level_name, level = self._read_index_node(level_node)
        levels.append(level)
        names.append(level_name)

        label_node = getattr(group, '%s_label%d' % (key, pos))
        labels.append(label_node[:])

    return MultiIndex(levels=levels, labels=labels, names=names)
487+
488+
def _read_index_node(self, node):
    """
    Read one stored index node.

    Returns
    -------
    (name, index) : tuple
        ``name`` is the node's ``name`` attribute, or None when the node has
        none; ``index`` is the node data passed through ``_unconvert_index``
        together with the node's ``kind`` attribute.
    """
    data = node[:]
    kind = node._v_attrs.kind

    # The name attribute is optional.  The getattr default only absorbs a
    # missing attribute (AttributeError) instead of hiding every possible
    # failure behind a blanket ``except Exception``.
    name = getattr(node._v_attrs, 'name', None)

    return name, _unconvert_index(data, kind)
430498

431499
def _write_array(self, group, key, value):
432500
if key in group:
@@ -531,21 +599,28 @@ def _read_group(self, group, where=None):
531599
return handler(group, where)
532600

533601
def _read_series(self, group, where=None):
534-
index = _read_index(group, 'index')
602+
index = self._read_index(group, 'index')
535603
values = _read_array(group, 'values')
536604
return Series(values, index=index)
537605

538606
def _read_legacy_series(self, group, where=None):
539-
index = _read_index_legacy(group, 'index')
607+
index = self._read_index_legacy(group, 'index')
540608
values = _read_array(group, 'values')
541609
return Series(values, index=index)
542610

543611
def _read_legacy_frame(self, group, where=None):
544-
index = _read_index_legacy(group, 'index')
545-
columns = _read_index_legacy(group, 'columns')
612+
index = self._read_index_legacy(group, 'index')
613+
columns = self._read_index_legacy(group, 'columns')
546614
values = _read_array(group, 'values')
547615
return DataFrame(values, index=index, columns=columns)
548616

617+
def _read_index_legacy(self, group, key):
    """
    Read the index stored at ``group.<key>`` using the legacy converter.

    The node's full data and its ``kind`` attribute are passed to
    ``_unconvert_index_legacy``.
    """
    legacy_node = getattr(group, key)
    return _unconvert_index_legacy(legacy_node[:], legacy_node._v_attrs.kind)
623+
549624
def _read_frame_table(self, group, where=None):
550625
return self._read_panel_table(group, where)['value']
551626

@@ -618,13 +693,6 @@ def _read_array(group, key):
618693
else:
619694
return data
620695

621-
def _read_index(group, key):
622-
node = getattr(group, key)
623-
data = node[:]
624-
kind = node._v_attrs.kind
625-
626-
return _unconvert_index(data, kind)
627-
628696
def _unconvert_index(data, kind):
629697
if kind == 'datetime':
630698
index = np.array([datetime.fromtimestamp(v) for v in data],
@@ -635,13 +703,6 @@ def _unconvert_index(data, kind):
635703
raise ValueError('unrecognized index type %s' % kind)
636704
return index
637705

638-
def _read_index_legacy(group, key):
639-
node = getattr(group, key)
640-
data = node[:]
641-
kind = node._v_attrs.kind
642-
643-
return _unconvert_index_legacy(data, kind)
644-
645706
def _unconvert_index_legacy(data, kind, legacy=False):
646707
if kind == 'datetime':
647708
index = _tseries.array_to_datetime(data)

pandas/io/tests/test_pytables.py

+14-1
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55

66
import numpy as np
77

8-
from pandas import (Series, DataFrame, Panel, LongPanel, DateRange)
8+
from pandas import (Series, DataFrame, Panel, LongPanel, DateRange,
9+
MultiIndex)
910
from pandas.io.pytables import HDFStore
1011
import pandas.util.testing as tm
1112

@@ -175,6 +176,18 @@ def test_frame(self):
175176
recons = self.store['df']
176177
self.assert_(recons._data.is_consolidated())
177178

179+
def test_store_hierarchical(self):
    """Round-trip a MultiIndex-ed frame, its transpose and one column."""
    level_values = [['foo', 'bar', 'baz', 'qux'],
                    ['one', 'two', 'three']]
    level_labels = [[0, 0, 0, 1, 1, 2, 2, 3, 3, 3],
                    [0, 1, 2, 0, 1, 1, 2, 0, 1, 2]]
    index = MultiIndex(levels=level_values, labels=level_labels)
    frame = DataFrame(np.random.randn(10, 3), index=index,
                      columns=['A', 'B', 'C'])

    # hierarchical index on the rows
    self._check_roundtrip(frame, tm.assert_frame_equal)
    # hierarchical index on the columns
    self._check_roundtrip(frame.T, tm.assert_frame_equal)
    # hierarchical index on a Series
    self._check_roundtrip(frame['A'], tm.assert_series_equal)
190+
178191
def test_store_mixed(self):
179192
def _make_one():
180193
df = tm.makeDataFrame()

pandas/util/testing.py

+1
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,7 @@ def assert_frame_equal(left, right):
119119
assert_series_equal(series, right[col])
120120
for col in right:
121121
assert(col in left)
122+
assert(left.index.equals(right.index))
122123
assert(left.columns.equals(right.columns))
123124

124125
def assert_panel_equal(left, right):

0 commit comments

Comments
 (0)