
Commit 81cfe3c

Simplify Model __new__ and metaclass (#7473)
* Type get_context correctly

get_context returns an instance of a Model, not a ContextMeta object.
We don't need the typevar, since we don't use it for anything special.

* Import from future to use delayed evaluation of annotations

All of these are supported on python>=3.9.

* New ModelManager class for managing model contexts

We create a global instance of it within this module, which is similar
to how it worked before, where a `context_class` attribute was attached
to the Model class.

We inherit from threading.local to ensure thread safety when working
with models on multiple threads. See #1552 for the reasoning. This is
already tested in `test_thread_safety`.

* Model class is now the context manager directly

* Fix type of UNSET in type definition

UNSET is the instance of the _UnsetType type. We should be typing the
latter here.

* Set model parent in init rather than in __new__

We use the new ModelManager.parent_context property to reliably set any
parent context, or else set it to None.

* Replace get_context in metaclass with classmethod

We set this directly on the class as a classmethod, which is clearer
than going via the metaclass.

* Remove get_contexts from metaclass

The original function does not behave as I expected. In the following
example I expected that it would return only the final model, not root.

This method is not used anywhere in the pymc codebase, so I have dropped
it from the codebase. I originally included the following code to
replace it, but since it is not used anyway, it is better to remove it.

```python
@classmethod
def get_contexts(cls) -> list[Model]:
    """Return a list of the currently active model contexts."""
    return MODEL_MANAGER.active_contexts
```

Example for testing behaviour in current main branch:

```python
import pymc as pm

with pm.Model(name="root") as root:
    print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="first") as first:
        print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="m_with_model_None", model=None) as m_with_model_None:
        # This one doesn't make much sense:
        print([c.name for c in pm.Model.get_contexts()])
```

* Simplify ContextMeta

We only keep the __call__ method, which is necessary to keep the model
context itself active during that model's __init__.

* Type Model.register_rv for downstream typing

In pymc/distributions/distribution.py, this change allows the type
checker to infer that `rv_out` can only be a TensorVariable.

Thanks to @ricardoV94 for type hint on rv_var.

* Include np.ndarray as possible type for coord values

I originally tried numpy's ArrayLike, replacing Sequence entirely, but
then I realized that ArrayLike also allows non-sequences like integers
and floats.

I am not certain if `values="a string"` should be legal. With the type
hint Sequence, it is. It might be more accurate, but more verbose, to use
`list | tuple | set | np.ndarray | None`.

* Use function-scoped new_dims to handle type hint varying throughout function

We don't want to allow the user to pass `dims=[None, None]` to our
function, but the current behaviour sets `dims=[None] * N` at the end of
`determine_coords`.

To handle this, I created a `new_dims` with a larger type scope which
matches the return type of `dims` in `determine_coords`.

Then I did the same within def Data to support this new type hint.

* Fix case of dims = [None, None, ...]

The only case where dims=[None, ...] is when the user has passed
dims=None. Since the user passed dims=None, they shouldn't be expecting
any coords to match that dimension. Thus we don't need to try to add any
more coords to the model.

* Remove unused hack
1 parent e25a042 commit 81cfe3c


2 files changed: +74 -146 lines changed


pymc/data.py

Lines changed: 15 additions & 11 deletions
@@ -221,9 +221,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
 def determine_coords(
     model,
     value: pd.DataFrame | pd.Series | xr.DataArray,
-    dims: Sequence[str | None] | None = None,
+    dims: Sequence[str] | None = None,
     coords: dict[str, Sequence | np.ndarray] | None = None,
-) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
+) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
     """Determine coordinate values from data or the model (via ``dims``)."""
     if coords is None:
         coords = {}
@@ -268,9 +268,10 @@ def determine_coords(

     if dims is None:
         # TODO: Also determine dim names from the index
-        dims = [None] * np.ndim(value)
-
-    return coords, dims
+        new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value)
+    else:
+        new_dims = dims
+    return coords, new_dims


 def ConstantData(
@@ -366,7 +367,7 @@ def Data(
         The name for this variable.
     value : array_like or pandas.Series, pandas.Dataframe
         A value to associate with this variable.
-    dims : str or tuple of str, optional
+    dims : str, tuple of str or tuple of None, optional
         Dimension names of the random variables (as opposed to the shapes of these
         random variables). Use this when ``value`` is a pandas Series or DataFrame. The
         ``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ
@@ -451,14 +452,17 @@ def Data(
             expected=x.ndim,
         )

+    new_dims: Sequence[str] | Sequence[None] | None
     if infer_dims_and_coords:
-        coords, dims = determine_coords(model, value, dims)
+        coords, new_dims = determine_coords(model, value, dims)
+    else:
+        new_dims = dims

-    if dims:
+    if new_dims:
         xshape = x.shape
         # Register new dimension lengths
-        for d, dname in enumerate(dims):
-            if dname not in model.dim_lengths:
+        for d, dname in enumerate(new_dims):
+            if dname not in model.dim_lengths and dname is not None:
                 model.add_coord(
                     name=dname,
                     # Note: Coordinate values can't be taken from
@@ -467,6 +471,6 @@ def Data(
                     length=xshape[d],
                 )

-    model.register_data_var(x, dims=dims)
+    model.register_data_var(x, dims=new_dims)

     return x
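
The `new_dims` change above follows a general typing pattern: rather than rebinding a parameter to a value of a wider type, the wider result gets its own name and annotation. A minimal standalone sketch of that pattern, together with the `None` guard added in `Data`, might look like this. The functions `resolve_dims` and `register_lengths` are hypothetical illustrations and are not part of pymc.

```python
from __future__ import annotations

from collections.abc import Sequence


def resolve_dims(
    ndim: int, dims: Sequence[str] | None = None
) -> Sequence[str] | Sequence[None]:
    """Hypothetical stand-in for the dims handling in determine_coords."""
    # `dims` keeps its narrow parameter type; the wider result gets its own name.
    new_dims: Sequence[str] | Sequence[None]
    if dims is None:
        new_dims = [None] * ndim
    else:
        new_dims = dims
    return new_dims


def register_lengths(
    shape: tuple[int, ...],
    new_dims: Sequence[str] | Sequence[None],
    dim_lengths: dict[str, int],
) -> None:
    """Record a length for every named dimension, skipping unnamed (None) ones."""
    for axis, dname in enumerate(new_dims):
        if dname is not None and dname not in dim_lengths:
            dim_lengths[dname] = shape[axis]


lengths: dict[str, int] = {}
register_lengths((3, 2), resolve_dims(2, ["row", "col"]), lengths)
register_lengths((5,), resolve_dims(1), lengths)  # dims=None registers nothing
print(lengths)  # {'row': 3, 'col': 2}
```

Keeping the parameter type narrow means callers cannot pass `[None, None]` themselves, while the function is still free to produce it internally.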

pymc/model/core.py

Lines changed: 59 additions & 135 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import functools
 import sys
@@ -19,13 +20,8 @@
 import warnings

 from collections.abc import Iterable, Sequence
-from sys import modules
 from typing import (
-    TYPE_CHECKING,
     Literal,
-    Optional,
-    TypeVar,
-    Union,
     cast,
     overload,
 )
@@ -42,7 +38,6 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
-from typing_extensions import Self

 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.data import is_valid_observed
@@ -73,6 +68,7 @@
     VarName,
     WithMemoization,
     _add_future_warning_tag,
+    _UnsetType,
     get_transformed_name,
     get_value_vars_from_user_vars,
     get_var_name,
@@ -92,118 +88,36 @@
 ]


-T = TypeVar("T", bound="ContextMeta")
+class ModelManager(threading.local):
+    """Keeps track of currently active model contexts.

+    A global instance of this is created in this module on import.
+    Use that instance, `MODEL_MANAGER` to inspect current contexts.

-class ContextMeta(type):
-    """Functionality for objects that put themselves in a context manager."""
-
-    def __new__(cls, name, bases, dct, **kwargs):
-        """Add __enter__ and __exit__ methods to the class."""
-
-        def __enter__(self):
-            self.__class__.context_class.get_contexts().append(self)
-            return self
-
-        def __exit__(self, typ, value, traceback):
-            self.__class__.context_class.get_contexts().pop()
+    It inherits from threading.local so is thread-safe, if models
+    can be entered/exited within individual threads.
+    """

-        dct[__enter__.__name__] = __enter__
-        dct[__exit__.__name__] = __exit__
+    def __init__(self):
+        self.active_contexts: list[Model] = []

-        # We strip off keyword args, per the warning from
-        # StackExchange:
-        # DO NOT send "**kwargs" to "type.__new__". It won't catch them and
-        # you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
-        return super().__new__(cls, name, bases, dct)
+    @property
+    def current_context(self) -> Model | None:
+        """Return the innermost context of any current contexts."""
+        return self.active_contexts[-1] if self.active_contexts else None

-    # FIXME: is there a more elegant way to automatically add methods to the class that
-    # are instance methods instead of class methods?
-    def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs):
-        """Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
-        if context_class is not None:
-            cls._context_class = context_class
-        super().__init__(name, bases, nmspc)
+    @property
+    def parent_context(self) -> Model | None:
+        """Return the parent context to the active context, if any."""
+        return self.active_contexts[-2] if len(self.active_contexts) > 1 else None

-    def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None:
-        """Return the most recently pushed context object of type ``cls`` on the stack, or ``None``.

-        If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``.
-        """
-        try:
-            candidate: T | None = cls.get_contexts()[-1]
-        except IndexError:
-            # Calling code expects to get a TypeError if the entity
-            # is unfound, and there's too much to fix.
-            if error_if_none:
-                raise TypeError(f"No {cls} on context stack")
-            return None
-        if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
-            raise BlockModelAccessError(candidate.error_msg_on_access)
-        return candidate
-
-    def get_contexts(cls) -> list[T]:
-        """Return a stack of context instances for the ``context_class`` of ``cls``."""
-        # This lazily creates the context class's contexts
-        # thread-local object, as needed. This seems inelegant to me,
-        # but since the context class is not guaranteed to exist when
-        # the metaclass is being instantiated, I couldn't figure out a
-        # better way. [2019/10/11:rpg]
-
-        # no race-condition here, contexts is a thread-local object
-        # be sure not to override contexts in a subclass however!
-        context_class = cls.context_class
-        assert isinstance(
-            context_class, type
-        ), f"Name of context class, {context_class} was not resolvable to a class"
-        if not hasattr(context_class, "contexts"):
-            context_class.contexts = threading.local()
-
-        contexts = context_class.contexts
-
-        if not hasattr(contexts, "stack"):
-            contexts.stack = []
-        return contexts.stack
-
-    # the following complex property accessor is necessary because the
-    # context_class may not have been created at the point it is
-    # specified, so the context_class may be a class *name* rather
-    # than a class.
-    @property
-    def context_class(cls) -> type:
-        def resolve_type(c: type | str) -> type:
-            if isinstance(c, str):
-                c = getattr(modules[cls.__module__], c)
-            if isinstance(c, type):
-                return c
-            raise ValueError(f"Cannot resolve context class {c}")
-
-        assert cls is not None
-        if isinstance(cls._context_class, str):
-            cls._context_class = resolve_type(cls._context_class)
-        if not isinstance(cls._context_class, str | type):
-            raise ValueError(
-                f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
-            )
-        return cls._context_class
-
-    # Inherit context class from parent
-    def __init_subclass__(cls, **kwargs):
-        super().__init_subclass__(**kwargs)
-        cls.context_class = super().context_class
-
-    # Initialize object in its own context...
-    # Merged from InitContextMeta in the original.
-    def __call__(cls, *args, **kwargs):
-        # We type hint Model here so type checkers understand that Model is a context manager.
-        # This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
-        instance: Model = cls.__new__(cls, *args, **kwargs)
-        with instance:  # appends context
-            instance.__init__(*args, **kwargs)
-        return instance
+# MODEL_MANAGER is instantiated at import, and serves as a truth for
+# what any currently active model contexts are.
+MODEL_MANAGER = ModelManager()


-def modelcontext(model: Optional["Model"]) -> "Model":
+def modelcontext(model: Model | None) -> Model:
     """Return the given model or, if None was supplied, try to find one in the context stack."""
     if model is None:
         model = Model.get_context(error_if_none=False)
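
To illustrate why `ModelManager` subclasses `threading.local`, here is a small standalone sketch of a thread-local context stack. The class `ContextStack`, the instance `STACK`, and the string "contexts" are hypothetical simplifications, not the pymc implementation. The behaviour it relies on is documented for `threading.local`: a subclass's `__init__` runs again in each thread that first touches the instance, so each thread gets its own list.

```python
from __future__ import annotations

import threading


class ContextStack(threading.local):
    """Hypothetical, simplified analogue of ModelManager."""

    def __init__(self):
        # Runs once per thread that uses the instance, so every thread
        # starts with its own empty list.
        self.active: list[str] = []

    @property
    def current(self) -> str | None:
        """Innermost active context, if any."""
        return self.active[-1] if self.active else None

    @property
    def parent(self) -> str | None:
        """Context just outside the innermost one, if any."""
        return self.active[-2] if len(self.active) > 1 else None


STACK = ContextStack()
STACK.active.extend(["root", "child"])

seen_in_worker: list[str | None] = []


def worker() -> None:
    # The worker thread sees an empty stack, not the main thread's entries.
    seen_in_worker.append(STACK.current)
    STACK.active.append("worker-only")


t = threading.Thread(target=worker)
t.start()
t.join()

print(STACK.current, STACK.parent)  # child root
print(seen_in_worker)  # [None]
print(STACK.active)  # ['root', 'child']
```

The worker's push never appears in the main thread's list, which is the isolation the commit message refers to when it cites `test_thread_safety`.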
@@ -372,6 +286,18 @@ def profile(self):
         return self._pytensor_function.profile


+class ContextMeta(type):
+    """A metaclass in order to apply a model's context during `Model.__init__``."""
+
+    # We want the Model's context to be active during __init__. In order for this
+    # to apply to subclasses of Model as well, we need to use a metaclass.
+    def __call__(cls: type[Model], *args, **kwargs):
+        instance = cls.__new__(cls, *args, **kwargs)
+        with instance:  # applies context
+            instance.__init__(*args, **kwargs)
+        return instance
+
+
 class Model(WithMemoization, metaclass=ContextMeta):
     """Encapsulates the variables and likelihood factors of a model.

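
The effect of the slimmed-down `ContextMeta.__call__` can be seen in a toy example. This is a sketch under assumptions: `InitInContext`, `Node`, `SubNode`, and the module-level `ACTIVE` list are hypothetical stand-ins for `ContextMeta`, `Model`, a `Model` subclass, and `MODEL_MANAGER.active_contexts`.

```python
from __future__ import annotations

ACTIVE: list[Node] = []  # hypothetical global stack of active contexts


class InitInContext(type):
    """Metaclass that enters the new instance's context around __init__."""

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls, *args, **kwargs)
        with instance:  # pushed before __init__ runs, popped afterwards
            instance.__init__(*args, **kwargs)
        return instance


class Node(metaclass=InitInContext):
    def __init__(self, name: str):
        self.name = name
        # During __init__ this instance is already the innermost context,
        # so the enclosing context (if any) sits one level below it.
        self.active_during_init = ACTIVE[-1] is self
        self.parent = ACTIVE[-2] if len(ACTIVE) > 1 else None

    def __enter__(self):
        ACTIVE.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ACTIVE.pop()


class SubNode(Node):
    """Subclasses inherit the metaclass, so they get the same behaviour."""


with Node("root") as root:
    child = SubNode("child")

print(child.active_during_init, child.parent is root, ACTIVE)  # True True []
```

Because the instance is pushed before `__init__` runs, the enclosing context is the second-to-last entry during initialisation, which is presumably why the commit reads the parent from `parent_context` rather than `current_context`.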
@@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta):

     """

-    if TYPE_CHECKING:
+    def __enter__(self):
+        """Enter the context manager."""
+        MODEL_MANAGER.active_contexts.append(self)
+        return self

-        def __enter__(self: Self) -> Self:
-            """Enter the context manager."""
-
-        def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
-            """Exit the context manager."""
-
-    def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
-        # resolves the parent instance
-        instance = super().__new__(cls)
-        if model is UNSET:
-            instance._parent = cls.get_context(error_if_none=False)
-        else:
-            instance._parent = model
-        return instance
+    def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
+        """Exit the context manager."""
+        _ = MODEL_MANAGER.active_contexts.pop()

     @staticmethod
     def _validate_name(name):
@@ -525,11 +443,11 @@ def __init__(
         check_bounds=True,
         *,
         coords_mutable=None,
-        model: Union[Literal[UNSET], None, "Model"] = UNSET,
+        model: _UnsetType | None | Model = UNSET,
     ):
-        del model  # used in __new__ to define the parent of this model
         self.name = self._validate_name(name)
         self.check_bounds = check_bounds
+        self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context

         if coords_mutable is not None:
             warnings.warn(
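
The `model` argument above distinguishes three cases: omitted (inherit the enclosing context), explicitly `None` (no parent), and an explicit parent. A sentinel type makes that distinction expressible, and an `isinstance` check is a form most type checkers can narrow on. The sketch below is a hypothetical illustration: `_Unset`, `UNSET`, and `resolve_parent` are stand-ins, with strings in place of models.

```python
from __future__ import annotations


class _Unset:
    """Hypothetical sentinel type, analogous in spirit to pymc's _UnsetType."""

    def __repr__(self) -> str:
        return "UNSET"


UNSET = _Unset()


def resolve_parent(model: _Unset | None | str, enclosing: str | None) -> str | None:
    """Pick the explicit argument if one was given, else the enclosing context."""
    # Annotating with the sentinel's *type* (not the instance) lets the
    # isinstance check below narrow `model` to `None | str` in the else branch.
    return enclosing if isinstance(model, _Unset) else model


print(resolve_parent(UNSET, "outer"))  # outer  (argument omitted: inherit)
print(resolve_parent(None, "outer"))  # None   (explicitly no parent)
print(resolve_parent("other", "outer"))  # other  (explicit parent wins)
```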
@@ -577,6 +495,17 @@ def __init__(
             functools.partial(str_for_model, formatting="latex"), self
         )

+    @classmethod
+    def get_context(
+        cls, error_if_none: bool = True, allow_block_model_access: bool = False
+    ) -> Model | None:
+        model = MODEL_MANAGER.current_context
+        if isinstance(model, BlockModelAccess) and not allow_block_model_access:
+            raise BlockModelAccessError(model.error_msg_on_access)
+        if model is None and error_if_none:
+            raise TypeError("No model on context stack")
+        return model
+
     @property
     def parent(self):
         return self._parent
@@ -967,7 +896,7 @@ def shape_from_dims(self, dims):
     def add_coord(
         self,
         name: str,
-        values: Sequence | None = None,
+        values: Sequence | np.ndarray | None = None,
         mutable: bool | None = None,
         *,
         length: int | Variable | None = None,
@@ -1233,16 +1162,16 @@ def set_data(

     def register_rv(
         self,
-        rv_var,
-        name,
+        rv_var: RandomVariable,
+        name: str,
         *,
         observed=None,
         total_size=None,
         dims=None,
         default_transform=UNSET,
         transform=UNSET,
         initval=None,
-    ):
+    ) -> TensorVariable:
         """Register an (un)observed random variable with the model.

         Parameters
@@ -2074,11 +2003,6 @@ def to_graphviz(
         )


-# this is really disgusting, but it breaks a self-loop: I can't pass Model
-# itself as context class init arg.
-Model._context_class = Model
-
-
 class BlockModelAccess(Model):
     """Can be used to prevent user access to Model contexts."""
