Commit 790475a (parent 5b58203)

feat: replace dict get with dunder method
1 file changed, 6 insertions(+), 6 deletions(-)

strings/bpe_tokenizer.py

@@ -7,7 +7,7 @@
 from collections import OrderedDict
 
 
-def get_byte_pair_counts(ids: list[int]):
+def get_byte_pair_counts(ids: list[int]) -> dict:
     """Count consecutive byte-pairs of an encoded string.
 
     >>> ids = [73, 32, 97, 109, 32, 74, 111, 110, 83, 110, 111, 119, 46]
@@ -23,7 +23,7 @@ def get_byte_pair_counts(ids: list[int]):
     return counts
 
 
-def merge(ids: list[int], pair: tuple, idx: int):
+def merge(ids: list[int], pair: tuple, idx: int) -> list[int]:
     """Replace most occurring byte pair with new byte that is not used
     in the data. For utf-8 encoding, we start with 256 as the new byte
@@ -48,12 +48,12 @@ def merge(ids: list[int], pair: tuple, idx: int):
 class Tokenizer:
     """Tokenize a string using the byte-pair encoding algorithm"""
 
-    def __init__(self, num_merges: int = 20, verbose: bool = False):
+    def __init__(self, num_merges: int = 20, verbose: bool = False) -> None:
         self.num_merges = num_merges
         self.merges: dict = {}
         self.verbose = verbose
 
-    def encode(self, text: str):
+    def encode(self, text: str) -> list[int]:
         """Convert a string to tokens (bytes)
 
         >>> t = Tokenizer()
@@ -80,7 +80,7 @@ def encode(self, text: str):
         # start merging most frequently occurring byte pairs
         for i in range(num_merges):
             counts = get_byte_pair_counts(ids)
-            pair = max(counts, key=counts.get)
+            pair = max(counts, key=counts.__getitem__)
 
             if counts[pair] == 1:
                 continue
@@ -93,7 +93,7 @@ def encode(self, text: str):
 
         return ids
 
-    def decode(self, ids: list[int]):
+    def decode(self, ids: list[int]) -> str:
         """Convert a list of tokens to the original string
 
         >>> t = Tokenizer()
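
Note on the key change: max(counts, key=counts.get) and max(counts, key=counts.__getitem__) select the same pair here, because max only probes keys that already exist in counts. The difference is the failure mode: get returns None for a missing key, while __getitem__ raises KeyError. A minimal sketch of the distinction, using a made-up counts dict for illustration:

    counts = {(97, 98): 3, (98, 99): 1}  # hypothetical byte-pair counts

    # Both key functions agree while every probed key exists:
    assert max(counts, key=counts.get) == (97, 98)
    assert max(counts, key=counts.__getitem__) == (97, 98)

    # The difference is the failure mode on a missing key:
    print(counts.get((1, 2)))  # None -- silently masks the miss
    try:
        counts[(1, 2)]         # __getitem__ raises -- fails loudly
    except KeyError:
        print("missing pair raises KeyError")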
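
For readers without the full file: the diff does not show the body of get_byte_pair_counts, but its docstring and doctest indicate it tallies adjacent byte pairs. A hypothetical sketch of such a counter (an assumption based on the docstring, not the repository's exact code):

    def count_adjacent_pairs(ids: list[int]) -> dict:
        # Hypothetical stand-in for get_byte_pair_counts: tally every
        # consecutive (byte, byte) pair in the encoded sequence.
        counts: dict = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    count_adjacent_pairs([73, 32, 32])  # {(73, 32): 1, (32, 32): 1}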
