7
7
from collections import OrderedDict
8
8
9
9
10
- def get_byte_pair_counts (ids : list [int ]):
10
+ def get_byte_pair_counts (ids : list [int ]) -> dict :
11
11
"""Count consecutive byte-pairs of an encoded string.
12
12
13
13
>>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
@@ -23,7 +23,7 @@ def get_byte_pair_counts(ids: list[int]):
23
23
return counts
24
24
25
25
26
- def merge (ids : list [int ], pair : tuple , idx : int ):
26
+ def merge (ids : list [int ], pair : tuple , idx : int ) -> list [ int ] :
27
27
"""Replace most occurring byte pair with new byte that is not used
28
28
in the data. For utf-8 encoding, we start with 256 as the new byte
29
29
@@ -48,12 +48,12 @@ def merge(ids: list[int], pair: tuple, idx: int):
48
48
class Tokenizer :
49
49
"""Tokenize a string using the byte-pair encoding algorithm"""
50
50
51
- def __init__ (self , num_merges : int = 20 , verbose : bool = False ):
51
+ def __init__ (self , num_merges : int = 20 , verbose : bool = False ) -> None :
52
52
self .num_merges = num_merges
53
53
self .merges : dict = {}
54
54
self .verbose = verbose
55
55
56
- def encode (self , text : str ):
56
+ def encode (self , text : str ) -> list [ int ] :
57
57
"""Convert a string to tokens (bytes)
58
58
59
59
>>> t = Tokenizer()
@@ -80,7 +80,7 @@ def encode(self, text: str):
80
80
# start merging most frequently occurring byte pairs
81
81
for i in range (num_merges ):
82
82
counts = get_byte_pair_counts (ids )
83
- pair = max (counts , key = counts .get )
83
+ pair = max (counts , key = counts .__getitem__ )
84
84
85
85
if counts [pair ] == 1 :
86
86
continue
@@ -93,7 +93,7 @@ def encode(self, text: str):
93
93
94
94
return ids
95
95
96
- def decode (self , ids : list [int ]):
96
+ def decode (self , ids : list [int ]) -> str :
97
97
"""Convert a list of tokens to the original string
98
98
99
99
>>> t = Tokenizer()
0 commit comments