From d2cbc7fb21e473710f20a57ce07a27a0eaca272b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 21 Feb 2025 19:24:05 +0100 Subject: [PATCH] adapt_input_conv: add type hints --- timm/models/_manipulate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/_manipulate.py b/timm/models/_manipulate.py index f40ff9ac47..c6042712d9 100644 --- a/timm/models/_manipulate.py +++ b/timm/models/_manipulate.py @@ -8,6 +8,7 @@ import torch import torch.utils.checkpoint from torch import nn as nn +from torch import Tensor from timm.layers import use_reentrant_ckpt @@ -284,7 +285,7 @@ def forward(_x): return x -def adapt_input_conv(in_chans, conv_weight): +def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor: conv_type = conv_weight.dtype conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU O, I, J, K = conv_weight.shape