@@ -7,7 +7,7 @@ def should_process_prefill_or_recompute(llmactor, env):
     """Check if the system should process prefill or recompute based on queue sizes and memory constraints."""
     return can_prefill_items(llmactor, env)
 
-def can_prefill_items(llmactor, env, ):
+def can_prefill_items(llmactor, env):
     """Are there items I can prefill?"""
     prefill_batch_size = 0
     num_new_seq = 0
@@ -25,8 +25,8 @@ def can_prefill_items(llmactor, env, ):
             break
 
     return True
-    while llmactor.get_prefill_queue_size() > 0:
 
+    while llmactor.get_prefill_queue_size() > 0:
         oldest_item = llmactor.prefill_store.items[0]
         oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining
         oldest_item_input_len = oldest_item.input_size
@@ -42,8 +42,7 @@ def can_prefill_items(llmactor, env, ):
 
     return False
 
-
-def fetch_prefill_items(llmactor, env, ):
+def fetch_prefill_items(llmactor, env):
     """Fetch items to prefill if there is memory either from recompute (p0) or from prefill (p1)"""
     items_to_prefill = []
     prefill_batch_size = 0
@@ -83,21 +82,19 @@ def fetch_prefill_items(llmactor, env, ):
             msg = yield llmactor.prefill_store.get()
             items_to_prefill.append(msg)
 
-
     return items_to_prefill
 
-def process_prefill_items( llmactor, env, items_to_prefill, req_dict_prefill, req_dict, logging = False):
+def process_prefill_items(llmactor, env, items_to_prefill, req_dict_prefill, req_dict, logging=False):
     """Process prefill items, updating times and managing item states."""
     prefill_len = np.sum([x.input_size + x.output_size - x.output_size_remaining for x in items_to_prefill])
-    prefill_delay = calculate_prefill_delay (prefill_len, len(items_to_prefill), TOKENIZE_LATENCY_CONST, PREFILL_LATENCY_CONST_2, PREFILL_LATENCY_CONST_1, PREFILL_LATENCY_CONST_0, PREFILL_LATENCY_CONST_MIN)
-
+    prefill_delay = calculate_prefill_delay(prefill_len, len(items_to_prefill), TOKENIZE_LATENCY_CONST, PREFILL_LATENCY_CONST_2, PREFILL_LATENCY_CONST_1, PREFILL_LATENCY_CONST_0, PREFILL_LATENCY_CONST_MIN)
 
     for item in items_to_prefill:
-        #lora stuff
+        # lora stuff
         if item.lora is not None:
-          if item.lora not in llmactor.lora_loaded :
+            if item.lora not in llmactor.lora_loaded:
                 llmactor.lora_loaded.add(item.lora)
-                llmactor.max_num_tokens_allowed -= LORA_DICT [item.lora]
+                llmactor.max_num_tokens_allowed -= LORA_DICT[item.lora]
 
         if item.start_prefill_time is None:
             item.start_prefill_time = env.now
@@ -110,9 +107,9 @@ def process_prefill_items( llmactor, env, items_to_prefill, req_dict_prefill, re
         else:
             llmactor.decode_store.put(item)
         if item.output_size_remaining <= 0:
-          if logging:
-            print(f'llmactor {llmactor.id} {item.id} item.output_size_remaining {item.output_size_remaining}')
-          assert item.output_size_remaining > 0
+            if logging:
+                print(f'llmactor {llmactor.id} {item.id} item.output_size_remaining {item.output_size_remaining}')
+            assert item.output_size_remaining > 0
         req_dict_prefill[item.id] = item
         req_dict[item.id] = item
     return prefill_delay
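For readers skimming the two hunks above: a sequence's KV-cache footprint is counted as input_size + output_size - output_size_remaining (the prompt plus the tokens generated so far), and loading a LoRA adapter permanently deducts from the actor's token budget. A small worked example with made-up sizes (the LORA_DICT entry is purely illustrative, not from the repo):

    # Prompt of 512 tokens; 128 to generate, 100 still remaining -> 28 done.
    input_size, output_size, output_size_remaining = 512, 128, 100
    tokens_in_kv = input_size + output_size - output_size_remaining  # 512 + 28 = 540

    # Loading an adapter shrinks the budget, mirroring
    # llmactor.max_num_tokens_allowed -= LORA_DICT[item.lora] above.
    LORA_DICT = {"my-adapter": 4096}       # illustrative entry
    max_num_tokens_allowed = 100_000       # hypothetical budget
    max_num_tokens_allowed -= LORA_DICT["my-adapter"]
    print(tokens_in_kv, max_num_tokens_allowed)  # 540 95904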
@@ -121,13 +118,13 @@ def should_recompute(llmactor, env):
     """Determine if items should be moved to recompute based on memory usage."""
    return llmactor.get_expected_num_tokens_in_kvcache_after_decode() / (llmactor.max_num_tokens_allowed + 0.0) > MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE
 
-def remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict, logging = False):
+def remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict, logging=False):
     """Manage the recomputation of items based on priority and conditions."""
     while should_recompute(llmactor, env):
         if llmactor.get_decode_queue_size() > 0:
             newest_decode_item_id = llmactor.decode_store.items[-1].id  # newest item goes to recompute
             if logging:
-              print(f'llmactor {llmactor.id} removing from decode store sequence {newest_decode_item_id}')
+                print(f'llmactor {llmactor.id} removing from decode store sequence {newest_decode_item_id}')
             req_dict[newest_decode_item_id].recompute_count += 1
 
             newest_decode_item = yield llmactor.decode_store.get(lambda req: req.id == newest_decode_item_id)
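The trigger in should_recompute is a plain occupancy ratio: eviction starts once the KV tokens expected after the next decode step exceed MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE of the actor's token budget. A minimal sketch with assumed numbers (the real constant is defined elsewhere in this file):

    MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE = 0.9  # assumed value, for illustration
    expected_tokens_after_decode = 92_000       # hypothetical
    max_num_tokens_allowed = 100_000            # hypothetical
    # Mirrors should_recompute: 0.92 > 0.9, so the newest decode item is
    # evicted to the recompute queue until the ratio drops below threshold.
    print(expected_tokens_after_decode / max_num_tokens_allowed
          > MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE)  # True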
@@ -140,13 +137,13 @@ def decode_items(llmactor, env, req_dict_prefill, req_dict, logging=False):
     temp_items = []
     decode_delay = calculate_decode_delay(before_decoding_token_count, num_items_to_decode, TOKENIZE_LATENCY_CONST, DECODE_LATENCY_CONST_1, DECODE_LATENCY_CONST_0, DECODE_LATENCY_CONST_BATCH)
     if logging:
-      print(f'llmactor {llmactor.id} Decoding sequences {[x.id for x in llmactor.decode_store.items]} items with delay {decode_delay}')
+        print(f'llmactor {llmactor.id} Decoding sequences {[x.id for x in llmactor.decode_store.items]} items with delay {decode_delay}')
 
     for _ in range(num_items_to_decode):
         msg = yield llmactor.decode_store.get()
-        if msg.output_size_remaining == msg.output_size - 1 :
-          msg.start_decode_time = env.now
-          msg.tokens_in_kv_cache_at_start_of_decode = before_decoding_token_count
+        if msg.output_size_remaining == msg.output_size - 1:
+            msg.start_decode_time = env.now
+            msg.tokens_in_kv_cache_at_start_of_decode = before_decoding_token_count
         msg.output_size_remaining -= 1
         if msg.output_size_remaining < 0:
             raise ValueError(f'Output size remaining negative for {msg.id}')
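decode_items charges one batched delay per decode step via calculate_decode_delay (defined in the next hunk): linear in the resident token count, plus a per-sequence tokenize/batch term. A toy evaluation with invented constants:

    # delay = tokens * c1 + c0 + (tokenize + batch) * batch_size
    TOKENIZE_LATENCY_CONST = 0.0001     # all constants here are made up
    DECODE_LATENCY_CONST_1 = 0.00002
    DECODE_LATENCY_CONST_0 = 0.01
    DECODE_LATENCY_CONST_BATCH = 0.0005
    tokens, batch = 10_000, 8
    delay = (tokens * DECODE_LATENCY_CONST_1 + DECODE_LATENCY_CONST_0
             + (TOKENIZE_LATENCY_CONST + DECODE_LATENCY_CONST_BATCH) * batch)
    print(delay)  # 0.2 + 0.01 + 0.0048 = 0.2148

calculate_prefill_delay follows the same pattern but is quadratic in the token count, with a floor of prefill_latency_const_min.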
@@ -155,57 +152,51 @@ def decode_items(llmactor, env, req_dict_prefill, req_dict, logging=False):
         req_dict_prefill[msg.id] = msg
         req_dict[msg.id] = msg
 
-
-
     for item in temp_items:
         if item.output_size_remaining == 0:
             item.end_decode_time = env.now + decode_delay
-
             llmactor.decoded_store.put(item)
         else:
             item.end_decode_time = env.now + decode_delay
             llmactor.decode_store.put(item)
 
     return decode_delay
 
-def calculate_decode_delay (token_count, num_items_to_decode, tokenize_latency_const, decode_latency_const_1, decode_latency_const_0, decode_latency_const_batch):
+def calculate_decode_delay(token_count, num_items_to_decode, tokenize_latency_const, decode_latency_const_1, decode_latency_const_0, decode_latency_const_batch):
     """Calculate delay based on the token count and latency constants."""
-    return token_count * decode_latency_const_1 + decode_latency_const_0 + (tokenize_latency_const + decode_latency_const_batch)* num_items_to_decode
+    return token_count * decode_latency_const_1 + decode_latency_const_0 + (tokenize_latency_const + decode_latency_const_batch) * num_items_to_decode
 
 def calculate_prefill_delay(token_count, num_items_to_prefill, tokenize_latency_const, prefill_latency_const_2, prefill_latency_const_1, prefill_latency_const_0, prefill_latency_const_min):
     """Calculate delay based on the token count and latency constants."""
-    return max (prefill_latency_const_min, (token_count * token_count * prefill_latency_const_2 + token_count * prefill_latency_const_1 + prefill_latency_const_0 + num_items_to_prefill * tokenize_latency_const))
+    return max(prefill_latency_const_min, (token_count * token_count * prefill_latency_const_2 + token_count * prefill_latency_const_1 + prefill_latency_const_0 + num_items_to_prefill * tokenize_latency_const))
 
-def prefill_or_decode(env, llmactor, req_dict_prefill, req_dict, logging = False):
+def prefill_or_decode(env, llmactor, req_dict_prefill, req_dict, logging=False):
     """Main process for managing prefill, decode, or recompute operations."""
     while True:
-
         with llmactor.actor.request() as req:
-
             yield req
             if (llmactor.get_decode_queue_size() == 0) and (llmactor.get_prefill_queue_size() == 0) and (llmactor.get_recompute_queue_size() == 0):
-              yield env.timeout(1/1000.0)
+                yield env.timeout(1/1000.0)
             elif should_process_prefill_or_recompute(llmactor, env):
                 items_to_prefill = yield from fetch_prefill_items(llmactor, env)
-                prefill_delay = process_prefill_items( llmactor, env ,items_to_prefill, req_dict_prefill, req_dict)
+                prefill_delay = process_prefill_items(llmactor, env, items_to_prefill, req_dict_prefill, req_dict)
                 if logging:
-                  print(f'llmactor {llmactor.id} Processed prefill for sequences {[x.id for x in items_to_prefill]} with delay {prefill_delay}')
+                    print(f'llmactor {llmactor.id} Processed prefill for sequences {[x.id for x in items_to_prefill]} with delay {prefill_delay}')
                 yield env.timeout(prefill_delay)  # Assume prefill_delay is calculated somewhere
             else:
-              if should_recompute(llmactor, env):
-                yield from remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict)
-              if llmactor.get_decode_queue_size() > 0:
-                decode_delay = yield from decode_items(llmactor, env, req_dict_prefill, req_dict)
-                yield env.timeout(decode_delay)
-
+                if should_recompute(llmactor, env):
+                    yield from remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict)
+                if llmactor.get_decode_queue_size() > 0:
+                    decode_delay = yield from decode_items(llmactor, env, req_dict_prefill, req_dict)
+                    yield env.timeout(decode_delay)
+
 def metrics(env, llmactor):
-  while True:
-    yield env.timeout(10)
-    cur_time = env.now
-    num_of_prompt_tokens = llmactor.get_num_prompt_tokens_in_decode() + llmactor.get_num_prompt_tokens_in_decoded()
-    num_of_gen_tokens = llmactor.get_num_gen_tokens_in_decode() + llmactor.get_num_gen_tokens_in_decoded()
-    running_req = llmactor.get_decode_queue_size()
-    pending_req = llmactor.get_prefill_queue_size()
-    gpu_kv_cache_usage = llmactor.get_num_tokens_in_decode()/llmactor.max_num_tokens_allowed * 100
-    print(f'llmactor {llmactor.id} Avg prompt throughput: {num_of_prompt_tokens / cur_time} tokens/s, Avg generation throughput: {num_of_gen_tokens / cur_time}, Running: {running_req} reqs, Pending: {pending_req} reqs, GPU KV cache usage: {gpu_kv_cache_usage}%')
-
+    while True:
+        yield env.timeout(10)
+        cur_time = env.now
+        num_of_prompt_tokens = llmactor.get_num_prompt_tokens_in_decode() + llmactor.get_num_prompt_tokens_in_decoded()
+        num_of_gen_tokens = llmactor.get_num_gen_tokens_in_decode() + llmactor.get_num_gen_tokens_in_decoded()
+        running_req = llmactor.get_decode_queue_size()
+        pending_req = llmactor.get_prefill_queue_size()
+        gpu_kv_cache_usage = llmactor.get_num_tokens_in_decode() / llmactor.max_num_tokens_allowed * 100
+        print(f'llmactor {llmactor.id} Avg prompt throughput: {num_of_prompt_tokens / cur_time} tokens/s, Avg generation throughput: {num_of_gen_tokens / cur_time}, Running: {running_req} reqs, Pending: {pending_req} reqs, GPU KV cache usage: {gpu_kv_cache_usage}%')
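Both prefill_or_decode and metrics are SimPy generator processes: the former holds the actor's resource while it batches prefill or decode work, and the latter wakes every 10 simulated time units to report throughput and KV-cache occupancy. A minimal wiring sketch, assuming an LLMActor class defined elsewhere in the repo that provides the stores and helpers used above (its constructor signature here is a guess):

    import simpy

    env = simpy.Environment()
    llmactor = LLMActor(env)              # hypothetical constructor
    req_dict, req_dict_prefill = {}, {}   # shared request registries

    # Register both generators as concurrent SimPy processes.
    env.process(prefill_or_decode(env, llmactor, req_dict_prefill, req_dict, logging=True))
    env.process(metrics(env, llmactor))
    env.run(until=1000)                   # simulate 1000 time units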