"""Prefill/decode scheduling and latency modelling for a simpy-based LLM serving simulator."""

import numpy as np
import simpy

from constants import (
    MAX_NUM_SEQ,
    MAX_NUM_BATCH_TOKENS,
    MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
    TOKENIZE_LATENCY_CONST,
    PREFILL_LATENCY_CONST_2,
    PREFILL_LATENCY_CONST_1,
    PREFILL_LATENCY_CONST_0,
    PREFILL_LATENCY_CONST_MIN,
    DECODE_LATENCY_CONST_1,
    DECODE_LATENCY_CONST_0,
    DECODE_LATENCY_CONST_BATCH,
    LORA_DICT,
)

def should_process_prefill_or_recompute(llmactor, env):
    """Check if the system should process prefill or recompute based on queue sizes and memory constraints."""
    return can_prefill_items(llmactor, env)

def can_prefill_items(llmactor, env):
    """Return True if at least one queued item (recompute first, then prefill) can be admitted to a prefill batch."""
    prefill_batch_size = 0
    num_new_seq = 0

    # Recompute queue has priority: check whether its oldest item can be admitted.
    if llmactor.get_recompute_queue_size() > 0:
        oldest_item = llmactor.recompute_store.items[0].item
        oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining

        if not any([
            # Admitting one more sequence would exceed the running-sequence limit.
            llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
            # The batch would exceed the per-batch token budget.
            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
            # KV-cache usage would cross the recompute threshold.
            (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
        ]):
            return True

    # Otherwise check the oldest item still waiting for its first prefill.
    if llmactor.get_prefill_queue_size() > 0:
        oldest_item = llmactor.prefill_store.items[0]
        oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining

        if not any([
            llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
            (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
        ]):
            return True

    return False

def fetch_prefill_items(llmactor, env):
    """Fetch items to prefill, memory permitting, first from recompute (p0) and then from prefill (p1)."""
    items_to_prefill = []
    prefill_batch_size = 0
    num_new_seq = 0

    # Drain the recompute queue first: these sequences were preempted and already hold progress.
    while llmactor.get_recompute_queue_size() > 0:
        oldest_item = llmactor.recompute_store.items[0].item
        oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining

        if any([
            llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
            (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
        ]):
            break

        prefill_batch_size += oldest_item_len
        num_new_seq += 1
        msg = yield llmactor.recompute_store.get()
        items_to_prefill.append(msg.item)

    # Then admit new sequences from the prefill queue under the same constraints.
    while llmactor.get_prefill_queue_size() > 0:
        oldest_item = llmactor.prefill_store.items[0]
        oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining

        if any([
            llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
            (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
        ]):
            break

        prefill_batch_size += oldest_item_len
        num_new_seq += 1
        msg = yield llmactor.prefill_store.get()
        items_to_prefill.append(msg)

    return items_to_prefill

def process_prefill_items(llmactor, env, items_to_prefill, req_dict_prefill, req_dict, logging=False):
    """Process prefill items, updating times and managing item states."""
    prefill_len = np.sum([x.input_size + x.output_size - x.output_size_remaining for x in items_to_prefill])
    prefill_delay = calculate_prefill_delay(prefill_len, len(items_to_prefill), TOKENIZE_LATENCY_CONST, PREFILL_LATENCY_CONST_2, PREFILL_LATENCY_CONST_1, PREFILL_LATENCY_CONST_0, PREFILL_LATENCY_CONST_MIN)

    for item in items_to_prefill:
        # Loading a new LoRA adapter permanently reserves part of the KV-cache token budget.
        if item.lora is not None:
            if item.lora not in llmactor.lora_loaded:
                llmactor.lora_loaded.add(item.lora)
                llmactor.max_num_tokens_allowed -= LORA_DICT[item.lora]

        if item.start_prefill_time is None:
            item.start_prefill_time = env.now
        item.end_prefill_time = item.start_prefill_time + prefill_delay
        item.end_decode_time = llmactor.env.now + prefill_delay
        item.output_size_remaining -= 1  # prefill emits the first output token

        if item.output_size_remaining == 0:
            llmactor.decoded_store.put(item)
        else:
            llmactor.decode_store.put(item)
        if item.output_size_remaining <= 0:
            if logging:
                print(f'llmactor {llmactor.id} {item.id} item.output_size_remaining {item.output_size_remaining}')
            assert item.output_size_remaining > 0
        req_dict_prefill[item.id] = item
        req_dict[item.id] = item
    return prefill_delay

def should_recompute(llmactor, env):
    """Determine if items should be moved to recompute based on memory usage."""
    return llmactor.get_expected_num_tokens_in_kvcache_after_decode() / (llmactor.max_num_tokens_allowed + 0.0) > MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE

def remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict, logging=False):
    """Preempt sequences from decode into the recompute queue while memory pressure is too high."""
    while should_recompute(llmactor, env):
        if llmactor.get_decode_queue_size() > 0:
            newest_decode_item_id = llmactor.decode_store.items[-1].id  # the newest item is preempted to recompute
            if logging:
                print(f'llmactor {llmactor.id} removing from decode store sequence {newest_decode_item_id}')
            req_dict[newest_decode_item_id].recompute_count += 1

            newest_decode_item = yield llmactor.decode_store.get(lambda req: req.id == newest_decode_item_id)
            llmactor.recompute_store.put(simpy.PriorityItem(item=newest_decode_item, priority=newest_decode_item_id))

def decode_items(llmactor, env, req_dict_prefill, req_dict, logging=False):
    """Process decoding of items, handling them appropriately based on their remaining output size."""
    num_items_to_decode = llmactor.get_decode_queue_size()
    before_decoding_token_count = llmactor.get_num_tokens_in_decode()
    temp_items = []
    decode_delay = calculate_decode_delay(before_decoding_token_count, num_items_to_decode, TOKENIZE_LATENCY_CONST, DECODE_LATENCY_CONST_1, DECODE_LATENCY_CONST_0, DECODE_LATENCY_CONST_BATCH)
    if logging:
        print(f'llmactor {llmactor.id} Decoding sequences {[x.id for x in llmactor.decode_store.items]} items with delay {decode_delay}')

    for _ in range(num_items_to_decode):
        msg = yield llmactor.decode_store.get()
        if msg.output_size_remaining == msg.output_size - 1:
            # First decode step for this sequence (prefill already emitted one token).
            msg.start_decode_time = env.now
            msg.tokens_in_kv_cache_at_start_of_decode = before_decoding_token_count
        msg.output_size_remaining -= 1
        if msg.output_size_remaining < 0:
            raise ValueError(f'Output size remaining negative for {msg.id}')

        temp_items.append(msg)
        req_dict_prefill[msg.id] = msg
        req_dict[msg.id] = msg

    # Finished sequences go to the decoded store; everything else goes back for the next decode step.
    for item in temp_items:
        item.end_decode_time = env.now + decode_delay
        if item.output_size_remaining == 0:
            llmactor.decoded_store.put(item)
        else:
            llmactor.decode_store.put(item)

    return decode_delay

def calculate_decode_delay(token_count, num_items_to_decode, tokenize_latency_const, decode_latency_const_1, decode_latency_const_0, decode_latency_const_batch):
    """Calculate decode delay: linear in the KV-cache token count plus a per-sequence batch cost."""
    return token_count * decode_latency_const_1 + decode_latency_const_0 + (tokenize_latency_const + decode_latency_const_batch) * num_items_to_decode

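# Worked example of the decode latency model above. The constants here are made-up
# placeholders for illustration only; the real values come from constants.py.
# With 2000 tokens already in the KV cache and a batch of 8 sequences:
#
#   calculate_decode_delay(2000, 8, tokenize_latency_const=0.1,
#                          decode_latency_const_1=1e-3, decode_latency_const_0=2.0,
#                          decode_latency_const_batch=0.5)
#   = 2000 * 1e-3 + 2.0 + (0.1 + 0.5) * 8 = 8.8  (in the simulator's time units)
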
def calculate_prefill_delay(token_count, num_items_to_prefill, tokenize_latency_const, prefill_latency_const_2, prefill_latency_const_1, prefill_latency_const_0, prefill_latency_const_min):
    """Calculate prefill delay: quadratic in the batched token count plus a per-sequence tokenize cost, with a floor."""
    return max(prefill_latency_const_min, (token_count * token_count * prefill_latency_const_2 + token_count * prefill_latency_const_1 + prefill_latency_const_0 + num_items_to_prefill * tokenize_latency_const))

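# Worked example of the prefill latency model above, again with illustrative
# placeholder constants rather than the values in constants.py. Prefilling 1000
# tokens across 4 sequences:
#
#   calculate_prefill_delay(1000, 4, tokenize_latency_const=0.1,
#                           prefill_latency_const_2=1e-6, prefill_latency_const_1=1e-3,
#                           prefill_latency_const_0=5.0, prefill_latency_const_min=10.0)
#   model value = 1000**2 * 1e-6 + 1000 * 1e-3 + 5.0 + 4 * 0.1 = 7.4,
#   which is clamped up to prefill_latency_const_min, so the call returns 10.0.
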
def prefill_or_decode(env, llmactor, req_dict_prefill, req_dict, logging=False):
    """Main process for managing prefill, decode, or recompute operations."""
    while True:
        with llmactor.actor.request() as req:
            yield req
            if (llmactor.get_decode_queue_size() == 0) and (llmactor.get_prefill_queue_size() == 0) and (llmactor.get_recompute_queue_size() == 0):
                # Nothing to do; idle briefly before polling the queues again.
                yield env.timeout(1 / 1000.0)
            elif should_process_prefill_or_recompute(llmactor, env):
                items_to_prefill = yield from fetch_prefill_items(llmactor, env)
                prefill_delay = process_prefill_items(llmactor, env, items_to_prefill, req_dict_prefill, req_dict)
                if logging:
                    print(f'llmactor {llmactor.id} Processed prefill for sequences {[x.id for x in items_to_prefill]} with delay {prefill_delay}')
                yield env.timeout(prefill_delay)
            else:
                if should_recompute(llmactor, env):
                    yield from remove_from_decode_store(llmactor, env, req_dict_prefill, req_dict)
                if llmactor.get_decode_queue_size() > 0:
                    decode_delay = yield from decode_items(llmactor, env, req_dict_prefill, req_dict)
                    yield env.timeout(decode_delay)

def metrics(env, llmactor):
    """Periodically report throughput, queue depths, and KV-cache usage."""
    while True:
        yield env.timeout(10)
        cur_time = env.now
        num_of_prompt_tokens = llmactor.get_num_prompt_tokens_in_decode() + llmactor.get_num_prompt_tokens_in_decoded()
        num_of_gen_tokens = llmactor.get_num_gen_tokens_in_decode() + llmactor.get_num_gen_tokens_in_decoded()
        running_req = llmactor.get_decode_queue_size()
        pending_req = llmactor.get_prefill_queue_size()
        gpu_kv_cache_usage = llmactor.get_num_tokens_in_decode() / llmactor.max_num_tokens_allowed * 100
        print(f'llmactor {llmactor.id} Avg prompt throughput: {num_of_prompt_tokens / cur_time} tokens/s, Avg generation throughput: {num_of_gen_tokens / cur_time} tokens/s, Running: {running_req} reqs, Pending: {pending_req} reqs, GPU KV cache usage: {gpu_kv_cache_usage}%')

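# Minimal wiring sketch. `LLMActor` and its constructor arguments are assumptions
# about the rest of this repo (this module only relies on its queue/store accessors),
# so the lines below are illustrative rather than a drop-in script:
#
#   env = simpy.Environment()
#   llmactor = LLMActor(env)          # hypothetical constructor
#   req_dict, req_dict_prefill = {}, {}
#   env.process(prefill_or_decode(env, llmactor, req_dict_prefill, req_dict, logging=True))
#   env.process(metrics(env, llmactor))
#   env.run(until=60_000)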