From 3be8b1abe4e38920a47a81db8a8636edec310290 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 30 Jan 2025 13:07:20 -0800
Subject: [PATCH 1/3] Change flattening behaviour in Kron

---
 timm/optim/kron.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index e01c9885be..9f4e496527 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
@@ -114,7 +115,8 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
-            flatten_dim: bool = False,
+            flatten: bool = False,
+            flatten_start_end: Tuple[int, int] = (2, -1),
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -141,7 +143,8 @@
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -229,8 +232,11 @@ def step(self, closure=None):
 
                 grad = p.grad
                 state = self.state[p]
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True
 
                 if len(state) == 0:
                     state["step"] = 0
@@ -341,7 +347,7 @@
                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
 
-                if group['flatten_dim']:
+                if flattened:
                     pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
@@ -361,6 +367,20 @@ def step(self, closure=None):
         return loss
 
 
+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
 def _init_Q_exprs(
         t,
         scale,
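Note: a minimal standalone sketch of the flattening behaviour introduced above. The safe_flatten helper is reproduced verbatim from the patch; the tensor shapes below are purely illustrative.

import torch


def safe_flatten(tensor, start_dim=0, end_dim=-1):
    # Reproduced from the patch above: flatten dims [start_dim, end_dim] when the
    # tensor actually has them, otherwise return it untouched.
    ndim = tensor.ndim
    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
    if ndim <= start_dim or start_dim > end_dim:
        return tensor
    return tensor.flatten(start_dim, end_dim)


# Conv2d weight grad (out, in, kh, kw) -> rank-3 (out, in, kh * kw)
print(safe_flatten(torch.zeros(64, 32, 3, 3), 2, -1).shape)  # torch.Size([64, 32, 9])
# Linear weight grad (out, in) has nothing at dim 2 to merge -> unchanged
print(safe_flatten(torch.zeros(64, 32), 2, -1).shape)        # torch.Size([64, 32])
# Bias grad (out,) -> unchanged
print(safe_flatten(torch.zeros(64), 2, -1).shape)            # torch.Size([64])

With the default (2, -1) range, a conv weight gradient keeps its output and input dims separate and only the trailing (e.g. spatial) dims are merged, so the preconditioner expressions see a rank-3 tensor instead of a rank-4 one.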
From 5940cc167f7426054a7f69c4358c7f2ac2655d5d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 30 Jan 2025 13:13:49 -0800
Subject: [PATCH 2/3] Change start/end args

---
 timm/optim/kron.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 9f4e496527..25c1b04716 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -116,7 +116,8 @@ def __init__(
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
             flatten: bool = False,
-            flatten_start_end: Tuple[int, int] = (2, -1),
+            flatten_start_dim: int = 2,
+            flatten_end_dim: int = -1,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -144,7 +145,8 @@
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
             flatten=flatten,
-            flatten_start_end=flatten_start_end,
+            flatten_start_dim=flatten_start_dim,
+            flatten_end_dim=flatten_end_dim,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -235,7 +237,7 @@ def step(self, closure=None):
 
                 flattened = False
                 if group['flatten']:
-                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
                     flattened = True
 
                 if len(state) == 0:

From 5f85f8eefa6f205d75a586aacb8399f36d530612 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 30 Jan 2025 15:42:27 -0800
Subject: [PATCH 3/3] Fix comment, add 'stochastic weight decay' idea because why not

---
 timm/optim/kron.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 25c1b04716..533354ecc8 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
@@ -118,6 +120,7 @@ def __init__(
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -353,11 +357,15 @@
                     pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)
 
                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
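Note: a hypothetical end-to-end usage sketch of the options added in this series. The import path, the model, and the lr / weight_decay values are assumptions for illustration; only the argument names themselves come from the constructor context in the diff, and opt_einsum must be installed. With stochastic_weight_decay enabled, the decay applied each step is weight_decay scaled by 2 * U(0, 1), so it averages out to the nominal weight_decay.

import torch
from timm.optim.kron import Kron  # assumed import path for timm/optim/kron.py

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 8, 3),
)

optimizer = Kron(
    model.parameters(),
    lr=1e-3,                        # illustrative values
    weight_decay=1e-2,
    decoupled_decay=True,           # AdamW-style decay path
    flatten=True,                   # flatten higher-rank grads before preconditioning
    flatten_start_dim=2,            # defaults: merge dims 2..-1, e.g. (out, in, kh, kw) -> (out, in, kh*kw)
    flatten_end_dim=-1,
    stochastic_weight_decay=True,   # per-step decay drawn from [0, 2 * weight_decay), mean = weight_decay
)

# One illustrative optimization step with random data.
x, y = torch.randn(4, 3, 32, 32), torch.randn(4, 8, 28, 28)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()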