diff --git a/pandas/compat/chainmap.py b/pandas/compat/chainmap.py index 84824207de2a9..493ba68bffd4a 100644 --- a/pandas/compat/chainmap.py +++ b/pandas/compat/chainmap.py @@ -15,3 +15,8 @@ def __delitem__(self, key): del mapping[key] return raise KeyError(key) + + def new_child(self, *args, **kwargs) -> "DeepChainMap": + # ChainMap.new_child returns self.__class__(...) but mypy + # doesn't know that, so we annotate it explicitly here. + return super().new_child(*args, **kwargs) # type: ignore diff --git a/pandas/core/computation/eval.py b/pandas/core/computation/eval.py index de2133f64291d..72f2e1d8e23e5 100644 --- a/pandas/core/computation/eval.py +++ b/pandas/core/computation/eval.py @@ -11,7 +11,7 @@ from pandas.core.computation.engines import _engines from pandas.core.computation.expr import Expr, _parsers, tokenize_string -from pandas.core.computation.scope import _ensure_scope +from pandas.core.computation.scope import ensure_scope from pandas.io.formats.printing import pprint_thing @@ -309,7 +309,7 @@ def eval( _check_for_locals(expr, level, parser) # get our (possibly passed-in) scope - env = _ensure_scope( + env = ensure_scope( level + 1, global_dict=global_dict, local_dict=local_dict, diff --git a/pandas/core/computation/pytables.py b/pandas/core/computation/pytables.py index 13a4814068d6a..13133d04ccd5d 100644 --- a/pandas/core/computation/pytables.py +++ b/pandas/core/computation/pytables.py @@ -13,7 +13,7 @@ import pandas as pd import pandas.core.common as com -from pandas.core.computation import expr, ops +from pandas.core.computation import expr, ops, scope as _scope from pandas.core.computation.common import _ensure_decoded from pandas.core.computation.expr import BaseExprVisitor from pandas.core.computation.ops import UndefinedVariableError, is_term @@ -21,7 +21,7 @@ from pandas.io.formats.printing import pprint_thing, pprint_thing_encoded -class Scope(expr.Scope): +class Scope(_scope.Scope): __slots__ = ("queryables",) def __init__(self, level, global_dict=None, local_dict=None, queryables=None): diff --git a/pandas/core/computation/scope.py b/pandas/core/computation/scope.py index ee82664f6cb21..9b9dde83f8584 100644 --- a/pandas/core/computation/scope.py +++ b/pandas/core/computation/scope.py @@ -1,7 +1,6 @@ """ Module for scope operations """ - import datetime import inspect from io import StringIO @@ -9,6 +8,7 @@ import pprint import struct import sys +from typing import Any, List, MutableMapping import numpy as np @@ -16,9 +16,9 @@ from pandas.compat.chainmap import DeepChainMap -def _ensure_scope( - level, global_dict=None, local_dict=None, resolvers=(), target=None, **kwargs -): +def ensure_scope( + level: int, global_dict=None, local_dict=None, resolvers=(), target=None, **kwargs +) -> "Scope": """Ensure that we are grabbing the correct scope.""" return Scope( level + 1, @@ -104,22 +104,28 @@ class Scope: """ __slots__ = ["level", "scope", "target", "resolvers", "temps"] + level: int + scope: DeepChainMap + resolvers: DeepChainMap + temps: MutableMapping[str, Any] def __init__( - self, level, global_dict=None, local_dict=None, resolvers=(), target=None + self, level: int, global_dict=None, local_dict=None, resolvers=(), target=None, ): self.level = level + 1 # shallow copy because we don't want to keep filling this up with what - # was there before if there are multiple calls to Scope/_ensure_scope + # was there before if there are multiple calls to Scope/ensure_scope self.scope = DeepChainMap(_DEFAULT_GLOBALS.copy()) self.target = target + assert all(isinstance(x, str) for x in self.scope), self.scope + if isinstance(local_dict, Scope): self.scope.update(local_dict.scope) if local_dict.target is not None: self.target = local_dict.target - self.update(local_dict.level) + self._update(local_dict.level) frame = sys._getframe(self.level) @@ -161,7 +167,7 @@ def has_resolvers(self) -> bool: """ return bool(len(self.resolvers)) - def resolve(self, key, is_local): + def resolve(self, key: str, is_local: bool): """ Resolve a variable name in a possibly local context. @@ -203,7 +209,7 @@ def resolve(self, key, is_local): raise UndefinedVariableError(key, is_local) - def swapkey(self, old_key, new_key, new_value=None): + def swapkey(self, old_key: str, new_key: str, new_value=None): """ Replace a variable name, with a potentially new value. @@ -216,10 +222,14 @@ def swapkey(self, old_key, new_key, new_value=None): new_value : object Value to be replaced along with the possible renaming """ + maps: List[MutableMapping] + mapping: MutableMapping + + # TODO: convince mypy that these maps are in fact mutable if self.has_resolvers: - maps = self.resolvers.maps + self.scope.maps + maps = self.resolvers.maps + self.scope.maps # type: ignore else: - maps = self.scope.maps + maps = self.scope.maps # type: ignore maps.append(self.temps) @@ -228,7 +238,7 @@ def swapkey(self, old_key, new_key, new_value=None): mapping[new_key] = new_value return - def _get_vars(self, stack, scopes): + def _get_vars(self, stack, scopes: List[str]): """ Get specifically scoped variables from a list of stack frames. @@ -241,9 +251,9 @@ def _get_vars(self, stack, scopes): evaluate to a dictionary. For example, ('locals', 'globals') """ variables = itertools.product(scopes, stack) - for scope, (frame, _, _, _, _, _) in variables: + for name, (frame, _, _, _, _, _) in variables: try: - d = getattr(frame, "f_" + scope) + d = getattr(frame, "f_" + name) self.scope = self.scope.new_child(d) finally: # won't remove it, but DECREF it @@ -251,7 +261,7 @@ def _get_vars(self, stack, scopes): # scope after the loop del frame - def update(self, level: int): + def _update(self, level: int): """ Update the current scope by going back `level` levels. @@ -303,7 +313,7 @@ def ntemps(self) -> int: return len(self.temps) @property - def full_scope(self): + def full_scope(self) -> DeepChainMap: """ Return the full scope for use with passing to engines transparently as a mapping. @@ -313,5 +323,6 @@ def full_scope(self): vars : DeepChainMap All variables in this scope. """ - maps = [self.temps] + self.resolvers.maps + self.scope.maps + # TODO: convince mypy that all of the maps are mutable + maps = [self.temps] + self.resolvers.maps + self.scope.maps # type: ignore return DeepChainMap(*maps)