diff --git a/timm/models/repvit.py b/timm/models/repvit.py index b0199b8986..43e35be900 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -176,6 +176,7 @@ def __init__(self, dim, num_classes, distillation=False): super().__init__() self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() self.distillation = distillation + self.num_classes=num_classes if distillation: self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()