diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py index f40ff9ac47..c6042712d9 100644 --- a/timm/models/_manipulate.py +++ b/timm/models/_manipulate.py @@ -8,6 +8,7 @@ import torch import torch.utils.checkpoint from torch import nn as nn +from torch import Tensor from timm.layers import use_reentrant_ckpt @@ -284,7 +285,7 @@ def forward(_x): return x -def adapt_input_conv(in_chans, conv_weight): +def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor: conv_type = conv_weight.dtype conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU O, I, J, K = conv_weight.shape