Skip to content

Commit e467bb9

Browse files
asinghgabamichaelosthege
authored andcommitted
Removed unused imports using autoflake
1 parent 2a5f65e commit e467bb9

File tree

2 files changed

+42
-54
lines changed

2 files changed

+42
-54
lines changed

pymc3/distributions/posterior_predictive.py

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,7 @@
77

88
from collections import UserDict
99
from contextlib import AbstractContextManager
10-
from typing import (
11-
TYPE_CHECKING,
12-
Any,
13-
Callable,
14-
Dict,
15-
List,
16-
Optional,
17-
Set,
18-
Tuple,
19-
Union,
20-
cast,
21-
overload,
22-
)
10+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, cast, overload
2311

2412
import numpy as np
2513
import theano.graph.basic
@@ -69,15 +57,15 @@ class _TraceDict(UserDict):
6957
~~~~~~~~~~
7058
varnames: list of strings"""
7159

72-
varnames: List[str]
60+
varnames: list[str]
7361
_len: int
7462
data: Point
7563

7664
def __init__(
7765
self,
78-
point_list: Optional[List[Point]] = None,
79-
multi_trace: Optional[MultiTrace] = None,
80-
dict_: Optional[Point] = None,
66+
point_list: list[Point] | None = None,
67+
multi_trace: MultiTrace | None = None,
68+
dict_: Point | None = None,
8169
):
8270
""""""
8371
if multi_trace:
@@ -134,11 +122,11 @@ def apply_slice(arr: np.ndarray) -> np.ndarray:
134122
return _TraceDict(dict_=sliced_dict)
135123

136124
@overload
137-
def __getitem__(self, item: Union[str, HasName]) -> np.ndarray:
125+
def __getitem__(self, item: str | HasName) -> np.ndarray:
138126
...
139127

140128
@overload
141-
def __getitem__(self, item: Union[slice, int]) -> _TraceDict:
129+
def __getitem__(self, item: slice | int) -> _TraceDict:
142130
...
143131

144132
def __getitem__(self, item):
@@ -155,13 +143,13 @@ def __getitem__(self, item):
155143

156144

157145
def fast_sample_posterior_predictive(
158-
trace: Union[MultiTrace, Dataset, InferenceData, List[Dict[str, np.ndarray]]],
159-
samples: Optional[int] = None,
160-
model: Optional[Model] = None,
161-
var_names: Optional[List[str]] = None,
146+
trace: MultiTrace | Dataset | InferenceData | list[dict[str, np.ndarray]],
147+
samples: int | None = None,
148+
model: Model | None = None,
149+
var_names: list[str] | None = None,
162150
keep_size: bool = False,
163151
random_seed=None,
164-
) -> Dict[str, np.ndarray]:
152+
) -> dict[str, np.ndarray]:
165153
"""Generate posterior predictive samples from a model given a trace.
166154
167155
This is a vectorized alternative to the standard ``sample_posterior_predictive`` function.
@@ -250,7 +238,7 @@ def fast_sample_posterior_predictive(
250238

251239
assert isinstance(_trace, _TraceDict)
252240

253-
_samples: List[int] = []
241+
_samples: list[int] = []
254242
# temporary replacement for more complicated logic.
255243
max_samples: int = len_trace
256244
if samples is None or samples == max_samples:
@@ -289,7 +277,7 @@ def fast_sample_posterior_predictive(
289277
_ETPParent = UserDict
290278

291279
class _ExtendableTrace(_ETPParent):
292-
def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
280+
def extend_trace(self, trace: dict[str, np.ndarray]) -> None:
293281
for k, v in trace.items():
294282
if k in self.data:
295283
self.data[k] = np.concatenate((self.data[k], v))
@@ -301,7 +289,7 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
301289
strace = _trace if s == len_trace else _trace[slice(0, s)]
302290
try:
303291
values = posterior_predictive_draw_values(cast(List[Any], vars), strace, s)
304-
new_trace: Dict[str, np.ndarray] = {k.name: v for (k, v) in zip(vars, values)}
292+
new_trace: dict[str, np.ndarray] = {k.name: v for (k, v) in zip(vars, values)}
305293
ppc_trace.extend_trace(new_trace)
306294
except KeyboardInterrupt:
307295
pass
@@ -313,8 +301,8 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
313301

314302

315303
def posterior_predictive_draw_values(
316-
vars: List[Any], trace: _TraceDict, samples: int
317-
) -> List[np.ndarray]:
304+
vars: list[Any], trace: _TraceDict, samples: int
305+
) -> list[np.ndarray]:
318306
with _PosteriorPredictiveSampler(vars, trace, samples, None) as sampler:
319307
return sampler.draw_values()
320308

@@ -323,25 +311,25 @@ class _PosteriorPredictiveSampler(AbstractContextManager):
323311
"""The process of posterior predictive sampling is quite complicated so this provides a central data store."""
324312

325313
# inputs
326-
vars: List[Any]
314+
vars: list[Any]
327315
trace: _TraceDict
328316
samples: int
329-
size: Optional[int] # not supported!
317+
size: int | None # not supported!
330318

331319
# other slots
332320
logger: logging.Logger
333321

334322
# for the search
335-
evaluated: Dict[int, np.ndarray]
336-
symbolic_params: List[Tuple[int, Any]]
323+
evaluated: dict[int, np.ndarray]
324+
symbolic_params: list[tuple[int, Any]]
337325

338326
# set by make_graph...
339-
leaf_nodes: Dict[str, Any]
340-
named_nodes_parents: Dict[str, Any]
341-
named_nodes_children: Dict[str, Any]
327+
leaf_nodes: dict[str, Any]
328+
named_nodes_parents: dict[str, Any]
329+
named_nodes_children: dict[str, Any]
342330
_tok: contextvars.Token
343331

344-
def __init__(self, vars, trace: _TraceDict, samples, model: Optional[Model], size=None):
332+
def __init__(self, vars, trace: _TraceDict, samples, model: Model | None, size=None):
345333
if size is not None:
346334
raise NotImplementedError(
347335
"sample_posterior_predictive does not support the size argument at this time."
@@ -361,7 +349,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
361349
vectorized_ppc.reset(self._tok)
362350
return False
363351

364-
def draw_values(self) -> List[np.ndarray]:
352+
def draw_values(self) -> list[np.ndarray]:
365353
vars = self.vars
366354
trace = self.trace
367355
samples = self.samples
@@ -438,8 +426,8 @@ def draw_values(self) -> List[np.ndarray]:
438426
# the below makes sure the graph is evaluated in order
439427
# test_distributions_random::TestDrawValues::test_draw_order fails without it
440428
# The remaining params that must be drawn are all hashable
441-
to_eval: Set[int] = set()
442-
missing_inputs: Set[int] = {j for j, p in self.symbolic_params}
429+
to_eval: set[int] = set()
430+
missing_inputs: set[int] = {j for j, p in self.symbolic_params}
443431

444432
while to_eval or missing_inputs:
445433
if to_eval == missing_inputs:
@@ -477,19 +465,19 @@ def init(self) -> None:
477465
from the posterior predictive distribution. Notably it initializes the
478466
``_DrawValuesContext`` bookkeeping object and evaluates the "fast drawable"
479467
parts of the model."""
480-
vars: List[Any] = self.vars
468+
vars: list[Any] = self.vars
481469
trace: _TraceDict = self.trace
482470
samples: int = self.samples
483-
leaf_nodes: Dict[str, Any]
484-
named_nodes_parents: Dict[str, Any]
485-
named_nodes_children: Dict[str, Any]
471+
leaf_nodes: dict[str, Any]
472+
named_nodes_parents: dict[str, Any]
473+
named_nodes_children: dict[str, Any]
486474

487475
# initialization phase
488476
context = _DrawValuesContext.get_context()
489477
assert isinstance(context, _DrawValuesContext)
490478
with context:
491479
drawn = context.drawn_vars
492-
evaluated: Dict[int, Any] = {}
480+
evaluated: dict[int, Any] = {}
493481
symbolic_params = []
494482
for i, var in enumerate(vars):
495483
if is_fast_drawable(var):
@@ -534,7 +522,7 @@ def make_graph(self) -> None:
534522
else:
535523
self.named_nodes_children[k].update(nnc[k])
536524

537-
def draw_value(self, param, trace: Optional[_TraceDict] = None, givens=None):
525+
def draw_value(self, param, trace: _TraceDict | None = None, givens=None):
538526
"""Draw a set of random values from a distribution or return a constant.
539527
540528
Parameters
@@ -559,7 +547,7 @@ def random_sample(
559547
param,
560548
point: _TraceDict,
561549
size: int,
562-
shape: Tuple[int, ...],
550+
shape: tuple[int, ...],
563551
) -> np.ndarray:
564552
val = meth(point=point, size=size)
565553
try:
@@ -591,7 +579,7 @@ def random_sample(
591579
elif hasattr(param, "random") and param.random is not None:
592580
model = modelcontext(None)
593581
assert isinstance(model, Model)
594-
shape: Tuple[int, ...] = tuple(_param_shape(param, model))
582+
shape: tuple[int, ...] = tuple(_param_shape(param, model))
595583
return random_sample(param.random, param, point=trace, size=samples, shape=shape)
596584
elif (
597585
hasattr(param, "distribution")
@@ -602,7 +590,7 @@ def random_sample(
602590
# shape inspection for ObservedRV
603591
dist_tmp = param.distribution
604592
try:
605-
distshape: Tuple[int, ...] = tuple(param.observations.shape.eval())
593+
distshape: tuple[int, ...] = tuple(param.observations.shape.eval())
606594
except AttributeError:
607595
distshape = tuple(param.observations.shape)
608596

@@ -689,7 +677,7 @@ def random_sample(
689677
raise ValueError("Unexpected type in draw_value: %s" % type(param))
690678

691679

692-
def _param_shape(var_desig, model: Model) -> Tuple[int, ...]:
680+
def _param_shape(var_desig, model: Model) -> tuple[int, ...]:
693681
if isinstance(var_desig, str):
694682
v = model[var_desig]
695683
else:

pymc3/plots/posteriorplot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import warnings
1818

19-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
19+
from typing import TYPE_CHECKING, Any, Callable
2020

2121
import matplotlib.pyplot as plt
2222
import numpy as np
@@ -28,9 +28,9 @@
2828

2929

3030
def plot_posterior_predictive_glm(
31-
trace: Union[InferenceData, MultiTrace],
32-
eval: Optional[np.ndarray] = None,
33-
lm: Optional[Callable] = None,
31+
trace: InferenceData | MultiTrace,
32+
eval: np.ndarray | None = None,
33+
lm: Callable | None = None,
3434
samples: int = 30,
3535
**kwargs: Any
3636
) -> None:

0 commit comments

Comments
 (0)