Skip to content

Commit e025328

Browse files
authored
simplify RNG
1 parent 6367267 commit e025328

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

timm/optim/kron.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,11 @@ def step(self, closure=None):
305305
if do_update:
306306
exprA, exprGs, _ = exprs
307307
Q = state["Q"]
308-
if self.deterministic is None:
308+
if self.deterministic:
309309
torch_rng = torch.Generator(device=V.device).manual_seed(self.rng.randint(0, 2 ** 31))
310-
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device='cpu')
311-
V = V.to(debiased_momentum.device)
312310
else:
313-
V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
311+
torch_rng = None
312+
V = torch.randn(debiased_momentum.shape, generator=torch_rng, dtype=precond_dtype, device=debiased_momentum.device)
314313
G = debiased_momentum if momentum_into_precond_update else grad
315314

316315
A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)

0 commit comments

Comments
 (0)