@@ -142,9 +142,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keep track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -200,7 +197,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -240,7 +236,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -1262,7 +1257,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
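For context outside the patch: a minimal sketch of the gradient behavior the last hunk changes. With Flash Attention 2 the masked-LM logits stay unpadded until the end of `forward`, and the re-padding now runs under `torch.no_grad()` only when the model is in eval mode and labels were passed, rather than consulting the removed `repad_logits_with_grad` flag. The helper name, tensor shapes, and padding call below are illustrative stand-ins, not the real `_pad_modernbert_output` signature.

```python
from contextlib import nullcontext

import torch
import torch.nn.functional as F


def repad_logits(logits: torch.Tensor, training: bool, has_labels: bool) -> torch.Tensor:
    # Mirror of the new condition: keep the autograd graph while training
    # (or when no labels were passed); otherwise repad under no_grad to save memory.
    ctx = nullcontext() if training or not has_labels else torch.no_grad()
    with ctx:
        # Placeholder for the real repadding helper: pad the sequence dim back to length 4.
        return F.pad(logits, (0, 0, 0, 4 - logits.shape[0]))


unpadded = torch.randn(2, 8, requires_grad=True)
train_out = repad_logits(unpadded, training=True, has_labels=True)
eval_out = repad_logits(unpadded, training=False, has_labels=True)
print(train_out.requires_grad, eval_out.requires_grad)  # True False
```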