Skip to content

Commit a197b19

Browse files
Even more type fixes
1 parent 372720e commit a197b19

File tree

7 files changed

+19
-17
lines changed

7 files changed

+19
-17
lines changed

pymc/backends/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181

8282
from pymc.backends.mcbackend import init_chain_adapters
8383

84-
TraceOrBackend = BaseTrace | Backend
84+
TraceOrBackend: TypeAlias = BaseTrace | Backend
8585
RunType: TypeAlias = Run
8686
HAS_MCB = True
8787
except ImportError:

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytensor import tensor as pt
2929
from pytensor.compile.builders import OpFromGraph
3030
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
31-
from pytensor.graph.basic import Node, Variable, io_toposort
31+
from pytensor.graph.basic import Apply, Variable, io_toposort
3232
from pytensor.graph.features import ReplaceValidate
3333
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
3434
from pytensor.graph.utils import MetaType
@@ -421,7 +421,7 @@ def __init__(
421421
kwargs.setdefault("strict", True)
422422
super().__init__(*args, **kwargs)
423423

424-
def update(self, node: Node) -> dict[Variable, Variable]:
424+
def update(self, node: Apply) -> dict[Variable, Variable]:
425425
"""Symbolic update expression for input random state variables
426426
427427
Returns a dictionary with the symbolic expressions required for correct updating
@@ -430,7 +430,7 @@ def update(self, node: Node) -> dict[Variable, Variable]:
430430
"""
431431
return collect_default_updates_inner_fgraph(node)
432432

433-
def batch_ndim(self, node: Node) -> int:
433+
def batch_ndim(self, node: Apply) -> int:
434434
"""Number of dimensions of the distribution's batch shape."""
435435
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
436436
return out_ndim - self.ndim_supp

pymc/distributions/mixture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytensor
1919
import pytensor.tensor as pt
2020

21-
from pytensor.graph.basic import Node, equal_computations
21+
from pytensor.graph.basic import Apply, equal_computations
2222
from pytensor.tensor import TensorVariable
2323
from pytensor.tensor.random.op import RandomVariable
2424
from pytensor.tensor.random.utils import normalize_size_param
@@ -156,7 +156,7 @@ def _resize_components(cls, size, *components):
156156

157157
return [change_dist_size(component, size) for component in components]
158158

159-
def update(self, node: Node):
159+
def update(self, node: Apply):
160160
# Update for the internal mix_indexes RV
161161
return {node.inputs[0]: node.outputs[0]}
162162

pymc/logprob/order.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
# SOFTWARE.
3636

3737

38+
from typing import cast
39+
3840
import pytensor.tensor as pt
3941

40-
from pytensor.graph.basic import Node
42+
from pytensor.graph.basic import Apply
4143
from pytensor.graph.fg import FunctionGraph
4244
from pytensor.graph.rewriting.basic import node_rewriter
4345
from pytensor.tensor.elemwise import Elemwise
@@ -72,15 +74,15 @@ class MeasurableMaxDiscrete(Max):
7274

7375

7476
@node_rewriter([Max])
75-
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
77+
def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
7678
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
7779
if rv_map_feature is None:
7880
return None # pragma: no cover
7981

8082
if isinstance(node.op, MeasurableMax):
8183
return None # pragma: no cover
8284

83-
base_var = node.inputs[0]
85+
base_var = cast(TensorVariable, node.inputs[0])
8486

8587
if base_var.owner is None:
8688
return None
@@ -104,6 +106,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariabl
104106
return None
105107

106108
# distinguish measurable discrete and continuous (because logprob is different)
109+
measurable_max: Max
107110
if base_var.owner.op.dtype.startswith("int"):
108111
measurable_max = MeasurableMaxDiscrete(list(axis))
109112
else:
@@ -173,7 +176,7 @@ class MeasurableDiscreteMaxNeg(Max):
173176

174177

175178
@node_rewriter(tracks=[Max])
176-
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
179+
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
177180
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
178181

179182
if rv_map_feature is None:
@@ -182,7 +185,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
182185
if isinstance(node.op, MeasurableMaxNeg):
183186
return None # pragma: no cover
184187

185-
base_var = node.inputs[0]
188+
base_var = cast(TensorVariable, node.inputs[0])
186189

187190
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
188191
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
@@ -213,6 +216,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVar
213216
return None
214217

215218
# distinguish measurable discrete and continuous (because logprob is different)
219+
measurable_min: Max
216220
if base_rv.owner.op.dtype.startswith("int"):
217221
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
218222
else:

pymc/pytensorf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,15 +926,15 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
926926
if inputs is None:
927927
inputs = []
928928

929-
outputs = makeiter(outputs)
930-
fg = FunctionGraph(outputs=outputs, clone=False)
929+
outs = makeiter(outputs)
930+
fg = FunctionGraph(outputs=outs, clone=False)
931931
clients = fg.clients
932932

933933
rng_updates = {}
934934
# Iterate over input RNGs. Only consider shared RNGs if `must_be_shared==True`
935935
for input_rng in (
936936
inp
937-
for inp in graph_inputs(outputs, blockers=inputs)
937+
for inp in graph_inputs(outs, blockers=inputs)
938938
if (
939939
(not must_be_shared or isinstance(inp, SharedVariable))
940940
and isinstance(inp.type, RandomType)

pymc/smc/sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def sample_smc(
225225
trace = MultiTrace(traces)
226226

227227
_t_sampling = time.time() - t1
228-
sample_stats, idata = _save_sample_stats(
228+
_, idata = _save_sample_stats(
229229
sample_settings,
230230
sample_stats,
231231
chains,

scripts/run_mypy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
pymc/distributions/continuous.py
2626
pymc/distributions/dist_math.py
2727
pymc/distributions/distribution.py
28-
pymc/distributions/mixture.py
2928
pymc/distributions/multivariate.py
3029
pymc/distributions/timeseries.py
3130
pymc/distributions/truncated.py
@@ -34,7 +33,6 @@
3433
pymc/logprob/censoring.py
3534
pymc/logprob/basic.py
3635
pymc/logprob/mixture.py
37-
pymc/logprob/order.py
3836
pymc/logprob/rewriting.py
3937
pymc/logprob/scan.py
4038
pymc/logprob/tensor.py

0 commit comments

Comments
 (0)