
Commit 6616e69

sgugger and n1t0 authored

Expand documentation of UnigramTrainer (#770)

* Expand documentation of UnigramTrainer
* Put doc at the source
* Add signature
* make style

Co-authored-by: Anthony Moi <[email protected]>
1 parent da4c7b1 commit 6616e69

File tree

2 files changed: +39 -2 lines changed


bindings/python/py_src/tokenizers/trainers/__init__.pyi

Lines changed: 24 additions & 1 deletion

@@ -72,9 +72,32 @@ class UnigramTrainer(Trainer):
             if not seen in the training dataset.
             If the strings contain more than one character, only the first one
             is kept.
+
+        shrinking_factor (:obj:`float`):
+            The shrinking factor used at each step of the training to prune the
+            vocabulary.
+
+        unk_token (:obj:`str`):
+            The token used for out-of-vocabulary tokens.
+
+        max_piece_length (:obj:`int`):
+            The maximum length of a given token.
+
+        n_sub_iterations (:obj:`int`):
+            The number of iterations of the EM algorithm to perform before
+            pruning the vocabulary.
     """

-    def __init__(self, vocab_size=8000, show_progress=True, special_tokens=[]):
+    def __init__(
+        self,
+        vocab_size=8000,
+        show_progress=True,
+        special_tokens=[],
+        shrinking_factor=0.75,
+        unk_token=None,
+        max_piece_length=16,
+        n_sub_iterations=2,
+    ):
         pass

 class WordLevelTrainer(Trainer):
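For context, a minimal usage sketch of the expanded `UnigramTrainer` signature. This is an illustration rather than part of the commit: the corpus path is a placeholder, and the exact argument order of `Tokenizer.train` has varied across releases of the library.

from tokenizers import Tokenizer
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer

# Build an empty Unigram tokenizer and a trainer using the newly documented knobs.
tokenizer = Tokenizer(Unigram())
trainer = UnigramTrainer(
    vocab_size=8000,
    show_progress=True,
    special_tokens=["<unk>", "<s>", "</s>"],
    shrinking_factor=0.75,   # shrinking factor applied at each vocabulary-pruning step
    unk_token="<unk>",       # token used for out-of-vocabulary pieces
    max_piece_length=16,     # maximum length of a single token
    n_sub_iterations=2,      # EM iterations run before each pruning step
)

# "corpus.txt" is a placeholder; train() takes a list of plain-text files.
tokenizer.train(["corpus.txt"], trainer=trainer)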

bindings/python/src/trainers.rs

Lines changed: 15 additions & 1 deletion

@@ -669,8 +669,22 @@ impl PyWordLevelTrainer {
 ///     if not seen in the training dataset.
 ///     If the strings contain more than one character, only the first one
 ///     is kept.
+///
+///     shrinking_factor (:obj:`float`):
+///         The shrinking factor used at each step of the training to prune the
+///         vocabulary.
+///
+///     unk_token (:obj:`str`):
+///         The token used for out-of-vocabulary tokens.
+///
+///     max_piece_length (:obj:`int`):
+///         The maximum length of a given token.
+///
+///     n_sub_iterations (:obj:`int`):
+///         The number of iterations of the EM algorithm to perform before
+///         pruning the vocabulary.
 #[pyclass(extends=PyTrainer, module = "tokenizers.trainers", name=UnigramTrainer)]
-#[text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens= [])"]
+#[text_signature = "(self, vocab_size=8000, show_progress=True, special_tokens=[], shrinking_factor=0.75, unk_token=None, max_piece_length=16, n_sub_iterations=2)"]
 pub struct PyUnigramTrainer {}
 #[pymethods]
 impl PyUnigramTrainer {
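On the Rust side, the `#[text_signature]` attribute is what exposes the argument list of this native class to Python introspection, which is why the commit updates it alongside the doc comment. A quick way to check the result after building the bindings (a sketch; the exact rendering depends on the installed version):

from tokenizers.trainers import UnigramTrainer

# help() prints the signature taken from #[text_signature] together with
# the parameter docs added in this commit.
help(UnigramTrainer)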
