Commit cedcb4e

Avoid needing repad_logits_with_grad, always repad with grads when training
I'm not 100% sure that the conditional with "or labels is None" makes sense, though - not sure what the intention is there. Perhaps we can remove that?
1 parent ab11657 commit cedcb4e
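
In effect, the commit drops the `repad_logits_with_grad` config flag and gates the repadding step on the module's training mode instead. A minimal sketch of the new gating, outside the actual model code (the `repad_context` helper name is hypothetical; the real conditional lives inline in `ModernBertForMaskedLM.forward`):

from contextlib import nullcontext

import torch


def repad_context(module: torch.nn.Module, labels):
    # Old gating: keep gradients only if config.repad_logits_with_grad was True
    # or no labels were passed; otherwise repad under torch.no_grad().
    # New gating: keep gradients whenever the module is in training mode or no
    # labels are passed; otherwise repad under torch.no_grad().
    return nullcontext() if module.training or labels is None else torch.no_grad()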

File tree: 3 files changed, +2 -12 lines


src/transformers/models/modernbert/configuration_modernbert.py (-5)

@@ -109,9 +109,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -167,7 +164,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -207,7 +203,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(

src/transformers/models/modernbert/modeling_modernbert.py (+1 -1)

@@ -1104,7 +1104,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
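
As a rough check of what the one-line change does, here is a toy stand-in (not the actual ModernBERT forward pass: a plain `Linear` replaces the model and an ordinary forward call replaces `_pad_modernbert_output`, so no flash-attn install is needed):

from contextlib import nullcontext

import torch

model = torch.nn.Linear(4, 4)
labels = torch.zeros(2, dtype=torch.long)
x = torch.randn(2, 4)

for mode in ("train", "eval"):
    getattr(model, mode)()  # model.train() or model.eval()
    # Same conditional as the new line above.
    ctx = nullcontext() if model.training or labels is None else torch.no_grad()
    with ctx:
        logits = model(x)  # stand-in for the repadding step
    print(mode, logits.requires_grad)  # train -> True, eval (with labels) -> False

With the old code, the eval-mode behaviour would additionally depend on `config.repad_logits_with_grad`; with the new code, only `self.training` and the presence of labels matter.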

src/transformers/models/modernbert/modular_modernbert.py (+1 -6)

@@ -142,9 +142,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -200,7 +197,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -240,7 +236,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -1262,7 +1257,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
