Skip to content

Commit 93158c5

Browse files
committed
Merge pull request #4470 from jreback/hdf_coord
ENH: allow where to be a list/array or a boolean mask of locations (GH4467)
2 parents 113e7d8 + 9a56c7e commit 93158c5

File tree

5 files changed

+93
-12
lines changed

5 files changed

+93
-12
lines changed

doc/source/io.rst

+16
Original file line numberDiff line numberDiff line change
@@ -2069,6 +2069,22 @@ These do not currently accept the ``where`` selector (coming soon)
20692069
store.select_column('df_dc', 'index')
20702070
store.select_column('df_dc', 'string')
20712071
2072+
.. _io.hdf5-where_mask:
2073+
2074+
**Selecting using a where mask**
2075+
2076+
Sometimes your query can involve creating a list of rows to select. Usually this ``mask`` would
2077+
be a resulting ``index`` from an indexing operation. This example selects the rows of
2078+
a ``DatetimeIndex`` whose month is 5 (May).
2079+
2080+
.. ipython:: python
2081+
2082+
df_mask = DataFrame(np.random.randn(1000,2),index=date_range('20000101',periods=1000))
2083+
store.append('df_mask',df_mask)
2084+
c = store.select_column('df_mask','index')
2085+
where = c[DatetimeIndex(c).month==5].index
2086+
store.select('df_mask',where=where)
2087+
20722088
**Replicating or**
20732089
20742090
``not`` and ``or`` conditions are unsupported at this time; however,

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pandas 0.13
8181
duplicate rows from a table (:issue:`4367`)
8282
- removed the ``warn`` argument from ``open``. Instead a ``PossibleDataLossError`` exception will
8383
be raised if you try to use ``mode='w'`` with an OPEN file handle (:issue:`4367`)
84+
- allow a passed locations array or mask as a ``where`` condition (:issue:`4467`)
8485

8586
**Experimental Features**
8687

doc/source/v0.13.0.txt

+7-5
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,19 @@ API changes
5959
store2.close()
6060
store2
6161

62+
- removed the ``_quiet`` attribute, replaced by a ``DuplicateWarning`` if retrieving
63+
duplicate rows from a table (:issue:`4367`)
64+
- removed the ``warn`` argument from ``open``. Instead a ``PossibleDataLossError`` exception will
65+
be raised if you try to use ``mode='w'`` with an OPEN file handle (:issue:`4367`)
66+
- allow a passed locations array or mask as a ``where`` condition (:issue:`4467`).
67+
See :ref:`here<io.hdf5-where_mask>` for an example.
68+
6269
.. ipython:: python
6370
:suppress:
6471

6572
import os
6673
os.remove(path)
6774

68-
- removed the ``_quiet`` attribute, replace by a ``DuplicateWarning`` if retrieving
69-
duplicate rows from a table (:issue:`4367`)
70-
- removed the ``warn`` argument from ``open``. Instead a ``PossibleDataLossError`` exception will
71-
be raised if you try to use ``mode='w'`` with an OPEN file handle (:issue:`4367`)
72-
7375
Enhancements
7476
~~~~~~~~~~~~
7577

pandas/io/pytables.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ def append_to_multiple(self, d, value, selector, data_columns=None, axes=None, *
744744
dc = data_columns if k == selector else None
745745

746746
# compute the val
747-
val = value.reindex_axis(v, axis=axis, copy=False)
747+
val = value.reindex_axis(v, axis=axis)
748748

749749
self.append(k, val, data_columns=dc, **kwargs)
750750

@@ -2674,7 +2674,7 @@ def create_axes(self, axes, obj, validate=True, nan_rep=None, data_columns=None,
26742674

26752675
# reindex by our non_index_axes & compute data_columns
26762676
for a in self.non_index_axes:
2677-
obj = obj.reindex_axis(a[1], axis=a[0], copy=False)
2677+
obj = obj.reindex_axis(a[1], axis=a[0])
26782678

26792679
# figure out data_columns and get out blocks
26802680
block_obj = self.get_object(obj).consolidate()
@@ -2684,10 +2684,10 @@ def create_axes(self, axes, obj, validate=True, nan_rep=None, data_columns=None,
26842684
data_columns = self.validate_data_columns(data_columns, min_itemsize)
26852685
if len(data_columns):
26862686
blocks = block_obj.reindex_axis(Index(axis_labels) - Index(
2687-
data_columns), axis=axis, copy=False)._data.blocks
2687+
data_columns), axis=axis)._data.blocks
26882688
for c in data_columns:
26892689
blocks.extend(block_obj.reindex_axis(
2690-
[c], axis=axis, copy=False)._data.blocks)
2690+
[c], axis=axis)._data.blocks)
26912691

26922692
# reorder the blocks in the same order as the existing_table if we can
26932693
if existing_table is not None:
@@ -2760,7 +2760,7 @@ def process_axes(self, obj, columns=None):
27602760
for axis, labels in self.non_index_axes:
27612761
if columns is not None:
27622762
labels = Index(labels) & Index(columns)
2763-
obj = obj.reindex_axis(labels, axis=axis, copy=False)
2763+
obj = obj.reindex_axis(labels, axis=axis)
27642764

27652765
# apply the selection filters (but keep in the same order)
27662766
if self.selection.filter:
@@ -3765,9 +3765,34 @@ def __init__(self, table, where=None, start=None, stop=None, **kwargs):
37653765
self.terms = None
37663766
self.coordinates = None
37673767

3768+
# a coordinate
37683769
if isinstance(where, Coordinates):
37693770
self.coordinates = where.values
3770-
else:
3771+
3772+
elif com.is_list_like(where):
3773+
3774+
# see if we have a passed coordinate like
3775+
try:
3776+
inferred = lib.infer_dtype(where)
3777+
if inferred=='integer' or inferred=='boolean':
3778+
where = np.array(where)
3779+
if where.dtype == np.bool_:
3780+
start, stop = self.start, self.stop
3781+
if start is None:
3782+
start = 0
3783+
if stop is None:
3784+
stop = self.table.nrows
3785+
self.coordinates = np.arange(start,stop)[where]
3786+
elif issubclass(where.dtype.type,np.integer):
3787+
if (self.start is not None and (where<self.start).any()) or (self.stop is not None and (where>=self.stop).any()):
3788+
raise ValueError("where must have index locations >= start and < stop")
3789+
self.coordinates = where
3790+
3791+
except:
3792+
pass
3793+
3794+
if self.coordinates is None:
3795+
37713796
self.terms = self.generate(where)
37723797

37733798
# create the numexpr & the filter

pandas/io/tests/test_pytables.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pandas
1313
from pandas import (Series, DataFrame, Panel, MultiIndex, bdate_range,
14-
date_range, Index)
14+
date_range, Index, DatetimeIndex)
1515
from pandas.io.pytables import (HDFStore, get_store, Term, read_hdf,
1616
IncompatibilityWarning, PerformanceWarning,
1717
AttributeConflictWarning, DuplicateWarning,
@@ -2535,6 +2535,43 @@ def test_coordinates(self):
25352535
expected = expected[(expected.A > 0) & (expected.B > 0)]
25362536
tm.assert_frame_equal(result, expected)
25372537

2538+
# pass array/mask as the coordinates
2539+
with ensure_clean(self.path) as store:
2540+
2541+
df = DataFrame(np.random.randn(1000,2),index=date_range('20000101',periods=1000))
2542+
store.append('df',df)
2543+
c = store.select_column('df','index')
2544+
where = c[DatetimeIndex(c).month==5].index
2545+
expected = df.iloc[where]
2546+
2547+
# locations
2548+
result = store.select('df',where=where)
2549+
tm.assert_frame_equal(result,expected)
2550+
2551+
# boolean
2552+
result = store.select('df',where=where)
2553+
tm.assert_frame_equal(result,expected)
2554+
2555+
# invalid
2556+
self.assertRaises(ValueError, store.select, 'df',where=np.arange(len(df),dtype='float64'))
2557+
self.assertRaises(ValueError, store.select, 'df',where=np.arange(len(df)+1))
2558+
self.assertRaises(ValueError, store.select, 'df',where=np.arange(len(df)),start=5)
2559+
self.assertRaises(ValueError, store.select, 'df',where=np.arange(len(df)),start=5,stop=10)
2560+
2561+
# list
2562+
df = DataFrame(np.random.randn(10,2))
2563+
store.append('df2',df)
2564+
result = store.select('df2',where=[0,3,5])
2565+
expected = df.iloc[[0,3,5]]
2566+
tm.assert_frame_equal(result,expected)
2567+
2568+
# boolean
2569+
where = [True] * 10
2570+
where[-2] = False
2571+
result = store.select('df2',where=where)
2572+
expected = df.loc[where]
2573+
tm.assert_frame_equal(result,expected)
2574+
25382575
def test_append_to_multiple(self):
25392576
df1 = tm.makeTimeDataFrame()
25402577
df2 = tm.makeTimeDataFrame().rename(columns=lambda x: "%s_2" % x)

0 commit comments

Comments
 (0)