Skip to content

Commit 7379833

Browse files
committed
Merge pull request #5566 from jreback/hdf_select
TST/API/BUG: resolve scoping issues in pytables query where rhs is a compound selection or scoped variable
2 parents e82c1f7 + 8bdf093 commit 7379833

File tree

4 files changed

+110
-19
lines changed

4 files changed

+110
-19
lines changed

pandas/computation/expr.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def __init__(self, gbls=None, lcls=None, level=1, resolvers=None,
125125
self.globals['True'] = True
126126
self.globals['False'] = False
127127

128+
# function defs
129+
self.globals['list'] = list
130+
self.globals['tuple'] = tuple
131+
128132
res_keys = (list(o.keys()) for o in self.resolvers)
129133
self.resolver_keys = frozenset(reduce(operator.add, res_keys, []))
130134
self._global_resolvers = self.resolvers + (self.locals, self.globals)
@@ -505,21 +509,21 @@ def _possibly_evaluate_binop(self, op, op_class, lhs, rhs,
505509
maybe_eval_in_python=('==', '!=')):
506510
res = op(lhs, rhs)
507511

508-
if (res.op in _cmp_ops_syms and lhs.is_datetime or rhs.is_datetime and
509-
self.engine != 'pytables'):
510-
# all date ops must be done in python bc numexpr doesn't work well
511-
# with NaT
512-
return self._possibly_eval(res, self.binary_ops)
512+
if self.engine != 'pytables':
513+
if (res.op in _cmp_ops_syms and getattr(lhs,'is_datetime',False) or getattr(rhs,'is_datetime',False)):
514+
# all date ops must be done in python bc numexpr doesn't work well
515+
# with NaT
516+
return self._possibly_eval(res, self.binary_ops)
513517

514518
if res.op in eval_in_python:
515519
# "in"/"not in" ops are always evaluated in python
516520
return self._possibly_eval(res, eval_in_python)
517-
elif (lhs.return_type == object or rhs.return_type == object and
518-
self.engine != 'pytables'):
519-
# evaluate "==" and "!=" in python if either of our operands has an
520-
# object return type
521-
return self._possibly_eval(res, eval_in_python +
522-
maybe_eval_in_python)
521+
elif self.engine != 'pytables':
522+
if (getattr(lhs,'return_type',None) == object or getattr(rhs,'return_type',None) == object):
523+
# evaluate "==" and "!=" in python if either of our operands has an
524+
# object return type
525+
return self._possibly_eval(res, eval_in_python +
526+
maybe_eval_in_python)
523527
return res
524528

525529
def visit_BinOp(self, node, **kwargs):
@@ -635,7 +639,7 @@ def visit_Attribute(self, node, **kwargs):
635639

636640
raise ValueError("Invalid Attribute context {0}".format(ctx.__name__))
637641

638-
def visit_Call(self, node, **kwargs):
642+
def visit_Call(self, node, side=None, **kwargs):
639643

640644
# this can happen with: datetime.datetime
641645
if isinstance(node.func, ast.Attribute):

pandas/computation/pytables.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -401,8 +401,15 @@ def visit_Assign(self, node, **kwargs):
401401
return self.visit(cmpr)
402402

403403
def visit_Subscript(self, node, **kwargs):
404+
# only allow simple suscripts
405+
404406
value = self.visit(node.value)
405407
slobj = self.visit(node.slice)
408+
try:
409+
value = value.value
410+
except:
411+
pass
412+
406413
try:
407414
return self.const_type(value[slobj], self.env)
408415
except TypeError:
@@ -416,9 +423,16 @@ def visit_Attribute(self, node, **kwargs):
416423
ctx = node.ctx.__class__
417424
if ctx == ast.Load:
418425
# resolve the value
419-
resolved = self.visit(value).value
426+
resolved = self.visit(value)
427+
428+
# try to get the value to see if we are another expression
429+
try:
430+
resolved = resolved.value
431+
except (AttributeError):
432+
pass
433+
420434
try:
421-
return getattr(resolved, attr)
435+
return self.term_type(getattr(resolved, attr), self.env)
422436
except AttributeError:
423437

424438
# something like datetime.datetime where scope is overriden

pandas/io/pytables.py

+4
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ def read_hdf(path_or_buf, key, **kwargs):
294294
295295
"""
296296

297+
# grab the scope
298+
if 'where' in kwargs:
299+
kwargs['where'] = _ensure_term(kwargs['where'])
300+
297301
f = lambda store, auto_close: store.select(
298302
key, auto_close=auto_close, **kwargs)
299303

pandas/io/tests/test_pytables.py

+74-5
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,19 @@ def ensure_clean_store(path, mode='a', complevel=None, complib=None,
8181
def ensure_clean_path(path):
8282
"""
8383
return essentially a named temporary file that is not opened
84-
and deleted on existing
84+
and deleted on existing; if path is a list, then create and
85+
return list of filenames
8586
"""
86-
8787
try:
88-
filename = create_tempfile(path)
89-
yield filename
88+
if isinstance(path, list):
89+
filenames = [ create_tempfile(p) for p in path ]
90+
yield filenames
91+
else:
92+
filenames = [ create_tempfile(path) ]
93+
yield filenames[0]
9094
finally:
91-
safe_remove(filename)
95+
for f in filenames:
96+
safe_remove(f)
9297

9398
# set these parameters so we don't have file sharing
9499
tables.parameters.MAX_NUMEXPR_THREADS = 1
@@ -3124,6 +3129,70 @@ def test_frame_select_complex(self):
31243129
expected = df.loc[df.index>df.index[3]].reindex(columns=['A','B'])
31253130
tm.assert_frame_equal(result, expected)
31263131

3132+
def test_frame_select_complex2(self):
3133+
3134+
with ensure_clean_path(['parms.hdf','hist.hdf']) as paths:
3135+
3136+
pp, hh = paths
3137+
3138+
# use non-trivial selection criteria
3139+
parms = DataFrame({ 'A' : [1,1,2,2,3] })
3140+
parms.to_hdf(pp,'df',mode='w',format='table',data_columns=['A'])
3141+
3142+
selection = read_hdf(pp,'df',where='A=[2,3]')
3143+
hist = DataFrame(np.random.randn(25,1),columns=['data'],
3144+
index=MultiIndex.from_tuples([ (i,j) for i in range(5) for j in range(5) ],
3145+
names=['l1','l2']))
3146+
3147+
hist.to_hdf(hh,'df',mode='w',format='table')
3148+
3149+
expected = read_hdf(hh,'df',where=Term('l1','=',[2,3,4]))
3150+
3151+
# list like
3152+
result = read_hdf(hh,'df',where=Term('l1','=',selection.index.tolist()))
3153+
assert_frame_equal(result, expected)
3154+
l = selection.index.tolist()
3155+
3156+
# sccope with list like
3157+
store = HDFStore(hh)
3158+
result = store.select('df',where='l1=l')
3159+
assert_frame_equal(result, expected)
3160+
store.close()
3161+
3162+
result = read_hdf(hh,'df',where='l1=l')
3163+
assert_frame_equal(result, expected)
3164+
3165+
# index
3166+
index = selection.index
3167+
result = read_hdf(hh,'df',where='l1=index')
3168+
assert_frame_equal(result, expected)
3169+
3170+
result = read_hdf(hh,'df',where='l1=selection.index')
3171+
assert_frame_equal(result, expected)
3172+
3173+
result = read_hdf(hh,'df',where='l1=selection.index.tolist()')
3174+
assert_frame_equal(result, expected)
3175+
3176+
result = read_hdf(hh,'df',where='l1=list(selection.index)')
3177+
assert_frame_equal(result, expected)
3178+
3179+
# sccope with index
3180+
store = HDFStore(hh)
3181+
3182+
result = store.select('df',where='l1=index')
3183+
assert_frame_equal(result, expected)
3184+
3185+
result = store.select('df',where='l1=selection.index')
3186+
assert_frame_equal(result, expected)
3187+
3188+
result = store.select('df',where='l1=selection.index.tolist()')
3189+
assert_frame_equal(result, expected)
3190+
3191+
result = store.select('df',where='l1=list(selection.index)')
3192+
assert_frame_equal(result, expected)
3193+
3194+
store.close()
3195+
31273196
def test_invalid_filtering(self):
31283197

31293198
# can't use more than one filter (atm)

0 commit comments

Comments
 (0)