Skip to content

Commit 9b78a6e

Browse files
fix: only keep stop sequence buffer if we have stop sequences
1 parent 80a6920 commit 9b78a6e

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

server/text_generation_server/utils/tokens.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self.stop_sequence_criterias = stop_sequence_criterias
113113
self.max_new_tokens = max_new_tokens
114114
self.current_tokens = 0
115-
self.current_output = "test"
115+
self.current_output = ""
116116
self.ignore_eos_token = ignore_eos_token
117117

118118
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
@@ -123,14 +123,15 @@ def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
123123
if not self.ignore_eos_token and last_token == self.eos_token_id:
124124
return True, FinishReason.FINISH_REASON_EOS_TOKEN
125125

126-
self.current_output += last_output
127-
# There is no need to keep an output that is too long
128-
if len(self.current_output) > 300:
129-
# Slice to -200 to avoid doing it all the time
130-
self.current_output = self.current_output[-200:]
131-
for stop_sequence_criteria in self.stop_sequence_criterias:
132-
if stop_sequence_criteria(self.current_output):
133-
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
126+
if self.stop_sequence_criterias:
127+
self.current_output += last_output
128+
# There is no need to keep an output that is too long
129+
if len(self.current_output) > 300:
130+
# Slice to -200 to avoid doing it all the time
131+
self.current_output = self.current_output[-200:]
132+
for stop_sequence_criteria in self.stop_sequence_criterias:
133+
if stop_sequence_criteria(self.current_output):
134+
return True, FinishReason.FINISH_REASON_STOP_SEQUENCE
134135

135136
return False, None
136137

0 commit comments

Comments
 (0)