Skip to content

Commit 372720e

Browse files
Fix typing in multiple places
1 parent b58fc7a commit 372720e

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

pymc/printing.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from functools import partial
17+
1618
from pytensor.compile import SharedVariable
1719
from pytensor.graph.basic import Constant, walk
1820
from pytensor.tensor.basic import TensorVariable, Variable
@@ -55,7 +57,7 @@ def str_for_dist(
5557

5658
if "latex" in formatting:
5759
if print_name is not None:
58-
print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"
60+
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
5961

6062
op_name = (
6163
dist.owner.op._print_name[1]
@@ -96,17 +98,16 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
9698
"""Make a human-readable string representation of Model, listing all random variables
9799
and their distributions, optionally including parameter values."""
98100

99-
kwargs = dict(formatting=formatting, include_params=include_params)
100-
free_rv_reprs = [str_for_dist(dist, **kwargs) for dist in model.free_RVs]
101-
observed_rv_reprs = [str_for_dist(rv, **kwargs) for rv in model.observed_RVs]
102-
det_reprs = [
103-
str_for_potential_or_deterministic(dist, **kwargs, dist_name="Deterministic")
104-
for dist in model.deterministics
105-
]
106-
potential_reprs = [
107-
str_for_potential_or_deterministic(pot, **kwargs, dist_name="Potential")
108-
for pot in model.potentials
109-
]
101+
# Wrap functions to avoid confusing typecheckers
102+
sfd = partial(str_for_dist, formatting=formatting, include_params=include_params)
103+
sfp = partial(
104+
str_for_potential_or_deterministic, formatting=formatting, include_params=include_params
105+
)
106+
107+
free_rv_reprs = [sfd(dist) for dist in model.free_RVs]
108+
observed_rv_reprs = [sfd(rv) for rv in model.observed_RVs]
109+
det_reprs = [sfp(dist, dist_name="Deterministic") for dist in model.deterministics]
110+
potential_reprs = [sfp(pot, dist_name="Potential") for pot in model.potentials]
110111

111112
var_reprs = free_rv_reprs + det_reprs + observed_rv_reprs + potential_reprs
112113

@@ -162,6 +163,8 @@ def _str_for_input_var(var: Variable, formatting: str) -> str:
162163
from pymc.distributions.distribution import SymbolicRandomVariable
163164

164165
def _is_potential_or_deterministic(var: Variable) -> bool:
166+
if not hasattr(var, "str_repr"):
167+
return False
165168
try:
166169
return var.str_repr.__func__.func is str_for_potential_or_deterministic
167170
except AttributeError:
@@ -175,14 +178,15 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
175178
) or _is_potential_or_deterministic(var):
176179
# show the names for RandomVariables, Deterministics, and Potentials, rather
177180
# than the full expression
181+
assert isinstance(var, TensorVariable)
178182
return _str_for_input_rv(var, formatting)
179183
elif isinstance(var.owner.op, DimShuffle):
180184
return _str_for_input_var(var.owner.inputs[0], formatting)
181185
else:
182186
return _str_for_expression(var, formatting)
183187

184188

185-
def _str_for_input_rv(var: Variable, formatting: str) -> str:
189+
def _str_for_input_rv(var: TensorVariable, formatting: str) -> str:
186190
_str = (
187191
var.name
188192
if var.name is not None
@@ -221,12 +225,15 @@ def _expand(x):
221225
if x.owner and (not isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)):
222226
return reversed(x.owner.inputs)
223227

224-
parents = [
225-
x
226-
for x in walk(nodes=var.owner.inputs, expand=_expand)
227-
if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
228-
]
229-
names = [x.name for x in parents]
228+
parents = []
229+
names = []
230+
for x in walk(nodes=var.owner.inputs, expand=_expand):
231+
assert isinstance(x, Variable)
232+
if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable):
233+
parents.append(x)
234+
xname = x.name
235+
assert xname is not None
236+
names.append(xname)
230237

231238
if "latex" in formatting:
232239
return (
@@ -257,6 +264,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
257264
"""Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
258265
# we know that our str_repr does not recurse, so we can ignore cycle
259266
try:
267+
if not hasattr(obj, "str_repr"):
268+
raise AttributeError
260269
output = obj.str_repr()
261270
# Find newlines and replace them with p.break_()
262271
# (see IPython.lib.pretty._repr_pprint)

pymc/pytensorf.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import warnings
1515

1616
from collections.abc import Callable, Generator, Iterable, Sequence
17+
from typing import cast
1718

1819
import numpy as np
1920
import pandas as pd
@@ -29,7 +30,6 @@
2930
from pytensor.graph.basic import (
3031
Apply,
3132
Constant,
32-
Node,
3333
Variable,
3434
clone_get_equiv,
3535
graph_inputs,
@@ -208,8 +208,8 @@ def replace_vars_in_graphs(
208208
"""
209209
# Clone graphs and get equivalences
210210
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
211-
equiv = {k: k for k in replacements.keys()}
212-
equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
211+
memo = {k: k for k in replacements.keys()}
212+
equiv = clone_get_equiv(inputs, graphs, False, False, memo)
213213

214214
fg = FunctionGraph(
215215
[equiv[i] for i in inputs],
@@ -753,7 +753,7 @@ def find_rng_nodes(
753753
]
754754

755755

756-
def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
756+
def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> list[TensorVariable]:
757757
"""Replace any RNG nodes upstream of outputs by new RNGs of the same type
758758
759759
This can be used when combining a pre-existing graph with a cloned one, to ensure
@@ -775,7 +775,7 @@ def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVaria
775775
rng_cls = np.random.Generator
776776
new_rng_nodes.append(pytensor.shared(rng_cls(np.random.PCG64())))
777777
graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
778-
return graph.outputs
778+
return cast(list[TensorVariable], graph.outputs)
779779

780780

781781
SeedSequenceSeed = None | int | Sequence[int] | np.ndarray | np.random.SeedSequence
@@ -798,7 +798,7 @@ def reseed_rngs(
798798
rng.set_value(new_rng, borrow=True)
799799

800800

801-
def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]:
801+
def collect_default_updates_inner_fgraph(node: Apply) -> dict[Variable, Variable]:
802802
"""Collect default updates from node with inner fgraph."""
803803
op = node.op
804804
inner_updates = collect_default_updates(
@@ -945,7 +945,7 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
945945
default_update = find_default_update(clients, input_rng)
946946

947947
# Respect default update if provided
948-
if getattr(input_rng, "default_update", None):
948+
if hasattr(input_rng, "default_update") and input_rng.default_update is not None:
949949
rng_updates[input_rng] = input_rng.default_update
950950
else:
951951
if default_update is not None:
@@ -1001,7 +1001,8 @@ def compile_pymc(
10011001

10021002
# We always reseed random variables as this provides RNGs with no chances of collision
10031003
if rng_updates:
1004-
reseed_rngs(rng_updates.keys(), random_seed)
1004+
rngs = cast(list[SharedVariable], list(rng_updates))
1005+
reseed_rngs(rngs, random_seed)
10051006

10061007
# If called inside a model context, see if check_bounds flag is set to False
10071008
try:

scripts/run_mypy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
pymc/model/core.py
4545
pymc/model/fgraph.py
4646
pymc/model/transform/conditioning.py
47-
pymc/printing.py
4847
pymc/pytensorf.py
4948
pymc/sampling/jax.py
5049
"""

0 commit comments

Comments
 (0)