Commit 1e9e3cc

fix recompute bug
1 parent 1fdb7b5 commit 1e9e3cc

12 files changed: +386 -14 lines changed
5 binary files not shown.
@@ -0,0 +1,258 @@
import argparse
from collections import Counter
import csv
from datetime import datetime
import numpy as np
import simpy
from llmactor import LLMActor
from loadbalancer import LoadBalancer

def main():
    parser = argparse.ArgumentParser(description="Simulate LLM load balancing with configurable parameters.")
    parser.add_argument("--rates_lo", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of arrival rates for low-priority requests.")
    parser.add_argument("--rates_hi", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of arrival rates for high-priority requests.")
    parser.add_argument("--no_of_messages", type=int, default=2500, help="Number of messages to simulate.")
    parser.add_argument("--mean_request_size_1", type=int, default=202, help="Mean request size for set 1.")
    parser.add_argument("--std_request_size_1", type=int, default=20, help="Standard deviation of request size for set 1.")
    parser.add_argument("--mean_output_size_1", type=int, default=179, help="Mean output size for set 1.")
    parser.add_argument("--std_output_size_1", type=int, default=17, help="Standard deviation of output size for set 1.")
    parser.add_argument("--mean_request_size_2", type=int, default=202, help="Mean request size for set 2.")
    parser.add_argument("--std_request_size_2", type=int, default=20, help="Standard deviation of request size for set 2.")
    parser.add_argument("--mean_output_size_2", type=int, default=179, help="Mean output size for set 2.")
    parser.add_argument("--std_output_size_2", type=int, default=17, help="Standard deviation of output size for set 2.")
    parser.add_argument("--queueing_perc", type=float, default=np.inf, help="Queueing percentage.")
    parser.add_argument('--target-latency-lo', nargs='+', type=float, help='List of target latencies for low priority requests.')
    parser.add_argument('--target-latency-hi', nargs='+', type=float, help='List of target latencies for high priority requests.')

    parser.add_argument('--prefix-latency-lo', nargs='+', type=float, help='List of latency-target prefixes for low priority requests.')
    parser.add_argument('--prefix-latency-hi', nargs='+', type=float, help='List of latency-target prefixes for high priority requests.')

    parser.add_argument('--number-of-servers', type=int, default=1, help='Number of servers to simulate.')

    args = parser.parse_args()

    # Use provided arguments or defaults
    rates_lo = args.rates_lo
    rates_hi = args.rates_hi
    no_of_messages = args.no_of_messages
    SIM_DURATIONS = [no_of_messages / r + 100 for r in rates_lo]

    mean_request_size_1 = args.mean_request_size_1
    std_request_size_1 = args.std_request_size_1
    mean_output_size_1 = args.mean_output_size_1
    std_output_size_1 = args.std_output_size_1

    mean_request_size_2 = args.mean_request_size_2
    std_request_size_2 = args.std_request_size_2
    mean_output_size_2 = args.mean_output_size_2
    std_output_size_2 = args.std_output_size_2

    queueing_perc = args.queueing_perc
    lora_requested_lo = ""
    lora_requested_hi = ""

    target_latency_list_lo = args.target_latency_lo if args.target_latency_lo else [0.025]
    target_latency_list_hi = args.target_latency_hi if args.target_latency_hi else [0.5]

    prefix_latency_list_lo = args.prefix_latency_lo if args.prefix_latency_lo else ['lo']
    prefix_latency_list_hi = args.prefix_latency_hi if args.prefix_latency_hi else ['hi']

    number_of_servers = args.number_of_servers

    # Define a structure to store results for all routing types
    results = {
        'leastPseudo': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                        'throughput_prefill': [], 'throughput_decode': [],
                        'throughput_prefill_lo': [], 'throughput_decode_lo': [],
                        'throughput_prefill_hi': [], 'throughput_decode_hi': [],
                        'ttft': [], 'ttft_lo': [], 'ttft_hi': [],
                        'tpot': [], 'tpot_lo': [], 'tpot_hi': [],
                        'target_pods_lo': [], 'target_pods_hi': [],
                        'recompute_cnt': [], 'recompute_cnt_hi': [], 'recompute_cnt_lo': [],
                        'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
                        'queue_time_lo': [], 'queue_time_hi': [],
                        'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
                        'avg_prefill_queue_size': [],
                        'avg_pending_tokens_perc': [],
                        'avg_actual_tokens_perc': []},

        'smart': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                  'estimated_latency': [], 'estimated_latency_lo': [], 'estimated_latency_hi': [],
                  'throughput_prefill': [], 'throughput_decode': [],
                  'throughput_prefill_lo': [], 'throughput_decode_lo': [],
                  'throughput_prefill_hi': [], 'throughput_decode_hi': [],
                  'ttft': [], 'ttft_lo': [], 'ttft_hi': [],
                  'tpot': [], 'tpot_lo': [], 'tpot_hi': [],
                  'target_pods_lo': [], 'target_pods_hi': [],
                  'recompute_cnt': [], 'recompute_cnt_hi': [], 'recompute_cnt_lo': [],
                  'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
                  'queue_time_lo': [], 'queue_time_hi': [],
                  'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
                  'avg_prefill_queue_size': [],
                  'avg_pending_tokens_perc': [],
                  'avg_actual_tokens_perc': []},

        'leastlatency': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                         'throughput_prefill': [], 'throughput_decode': [],
                         'throughput_prefill_lo': [], 'throughput_decode_lo': [],
                         'throughput_prefill_hi': [], 'throughput_decode_hi': [],
                         'ttft': [], 'ttft_lo': [], 'ttft_hi': [],
                         'tpot': [], 'tpot_lo': [], 'tpot_hi': [],
                         'target_pods_lo': [], 'target_pods_hi': [],
                         'recompute_cnt': [], 'recompute_cnt_hi': [], 'recompute_cnt_lo': [],
                         'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
                         'queue_time_lo': [], 'queue_time_hi': [],
                         'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
                         'avg_prefill_queue_size': [],
                         'avg_pending_tokens_perc': [],
                         'avg_actual_tokens_perc': []},

        'least': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                  'throughput_prefill': [], 'throughput_decode': [],
                  'throughput_prefill_lo': [], 'throughput_decode_lo': [],
                  'throughput_prefill_hi': [], 'throughput_decode_hi': [],
                  'ttft': [], 'ttft_lo': [], 'ttft_hi': [],
                  'tpot': [], 'tpot_lo': [], 'tpot_hi': [],
                  'target_pods_lo': [], 'target_pods_hi': [],
                  'recompute_cnt': [], 'recompute_cnt_hi': [], 'recompute_cnt_lo': [],
                  'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
                  'queue_time_lo': [], 'queue_time_hi': [],
                  'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
                  'avg_prefill_queue_size': [],
                  'avg_pending_tokens_perc': [],
                  'avg_actual_tokens_perc': []},

        'random': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                   'throughput_prefill': [], 'throughput_decode': [],
                   'throughput_prefill_lo': [], 'throughput_decode_lo': [],
                   'throughput_prefill_hi': [], 'throughput_decode_hi': [],
                   'ttft': [], 'ttft_lo': [], 'ttft_hi': [],
                   'tpot': [], 'tpot_lo': [], 'tpot_hi': [],
                   'target_pods_lo': [], 'target_pods_hi': [],
                   'recompute_cnt': [], 'recompute_cnt_hi': [], 'recompute_cnt_lo': [],
                   'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
                   'queue_time_lo': [], 'queue_time_hi': [],
                   'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
                   'avg_prefill_queue_size': [],
                   'avg_pending_tokens_perc': [],
                   'avg_actual_tokens_perc': []},
    }

    all_routing_types = ["random"]
    prompt_output_tuple = None

    # Iterate over routing types
    for routing_type in all_routing_types:
        print(f'Routing Type: {routing_type}')

        for i, _ in enumerate(rates_lo):
            req_dict = {}
            req_dict_prefill = {}
            SIM_DURATION = SIM_DURATIONS[i]
            print(f'Simulating with lo rate {rates_lo[i]}, hi rate {rates_hi[i]}, routing type {routing_type}')

            # Simpy environment and LLM actors setup
            env = simpy.Environment()
            list_of_llmactors = [LLMActor(env, 1, id) for id in range(number_of_servers)]
            lb = LoadBalancer(env, number_of_servers=number_of_servers, list_of_llmactors=list_of_llmactors,
                              req_dict_prefill=req_dict_prefill, req_dict=req_dict,
                              messages_remaining_cnt=no_of_messages)
            lb.queueing_perc = queueing_perc

            estimated_output_size = mean_output_size_1
            lb.process(rates_lo[i], lora_requested_lo, target_latency_list_lo, prefix_latency_list_lo,
                       routing_type, prompt_output_tuple, mean_request_size_1, std_request_size_1,
                       mean_output_size_1, std_output_size_1, estimated_output_size)
            env.run(until=SIM_DURATION)

            # Completed requests, sorted by arrival time
            completed_req = list(filter(lambda x: x.output_size_remaining == 0, req_dict.values()))
            completed_req_sorted = sorted(completed_req, key=lambda x: x.arrival_time)

            # Optionally exclude the earliest requests from the metrics
            # (the fraction is currently 0, i.e. all completed requests are kept).
            exclude_count = int(0 * len(completed_req_sorted))
            filtered_req = completed_req_sorted[exclude_count:]

            # Calculate ttft, tpot, latency, and throughput
            ttft_cur = np.mean([x.end_prefill_time - x.arrival_time for x in req_dict.values()])
            tpot_cur = np.mean([(x.end_decode_time - x.start_prefill_time) / (x.output_size - x.output_size_remaining) for x in req_dict.values()])
            latency_cur = np.mean([(x.end_decode_time - x.arrival_time) / (x.output_size - x.output_size_remaining) for x in filtered_req])
            estimated_latency_cur = np.mean([x.estimated_latency for x in filtered_req])
            recompute_cur = np.sum([x.recompute_count for x in filtered_req]) / len(filtered_req)

            tt = SIM_DURATION
            throughput_prefill_cur = np.sum([x.input_size for x in filtered_req]) / tt
            throughput_decode_cur = np.sum([max(0, x.output_size - x.output_size_remaining - 1) for x in filtered_req]) / tt

            pending_tokens_at_arrival_perc = [x.pending_tokens_at_arrival_perc for x in completed_req]
            actual_tokens_at_arrival_perc = [x.actual_tokens_at_arrival_perc for x in completed_req]
            prefill_queue_size = [x.queue_size_before_prefill for x in completed_req]

            # Store results for the current routing type
            results[routing_type]['latency'].append(latency_cur)
            results[routing_type]['throughput_prefill'].append(throughput_prefill_cur)
            results[routing_type]['throughput_decode'].append(throughput_decode_cur)
            results[routing_type]['ttft'].append(ttft_cur)
            results[routing_type]['tpot'].append(tpot_cur)
            results[routing_type]['recompute_cnt'].append(recompute_cur)

            results[routing_type]['avg_prefill_queue_size'].append(np.mean(prefill_queue_size))
            results[routing_type]['avg_pending_tokens_perc'].append(np.mean(pending_tokens_at_arrival_perc))
            results[routing_type]['avg_actual_tokens_perc'].append(np.mean(actual_tokens_at_arrival_perc))

    # Create a timestamp
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Create the output file name with the timestamp
    output_file = f"results_{timestamp}.csv"

    # Write results to CSV
    with open(output_file, 'w', newline='') as csvfile:
        fieldnames = ['RoutingType', 'RateIndex', 'Latency', 'avg_prefill_queue_size', 'avg_pending_tokens_perc', 'avg_actual_tokens_perc']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        # Iterate over routing types and write each entry
        for routing_type in all_routing_types:
            for i in range(len(rates_lo)):
                writer.writerow({
                    'RoutingType': routing_type,
                    'RateIndex': rates_lo[i],
                    'Latency': results[routing_type]['latency'][i],
                    'avg_prefill_queue_size': results[routing_type]['avg_prefill_queue_size'][i],
                    'avg_pending_tokens_perc': results[routing_type]['avg_pending_tokens_perc'][i],
                    'avg_actual_tokens_perc': results[routing_type]['avg_actual_tokens_perc'][i],
                })

    print(f"Results have been saved to {output_file}")


if __name__ == "__main__":
    main()
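
Note: each SIM_DURATIONS entry is no_of_messages / rate + 100 simulated time units, so lower rates get proportionally longer runs. A minimal sketch recomputing only that expression from the default arguments above:

# Sketch: simulation horizons implied by the default arguments.
no_of_messages = 2500
rates_lo = [35, 30, 25, 20, 15, 10, 5, 1]
sim_durations = [no_of_messages / r + 100 for r in rates_lo]
print([round(d, 1) for d in sim_durations])
# [171.4, 183.3, 200.0, 225.0, 266.7, 350.0, 600.0, 2600.0]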

simulations/llm_ig_simulation/continous_batching.py

+8 -4 lines changed

@@ -15,10 +15,11 @@ def can_prefill_items(llmactor, env, ):
     while llmactor.get_recompute_queue_size() > 0:
         oldest_item = llmactor.recompute_store.items[0].item
         oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining
+        oldest_item_input_len = oldest_item.input_size

         if any([
             llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
-            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
+            prefill_batch_size + oldest_item_input_len > MAX_NUM_BATCH_TOKENS,
             (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE
         ]):
             break
@@ -28,10 +29,11 @@ def can_prefill_items(llmactor, env, ):

         oldest_item = llmactor.prefill_store.items[0]
         oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining
+        oldest_item_input_len = oldest_item.input_size

         if any([
             llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
-            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
+            prefill_batch_size + oldest_item_input_len > MAX_NUM_BATCH_TOKENS,
             (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE
         ]):
             break
@@ -50,10 +52,11 @@ def fetch_prefill_items(llmactor, env, ):
     while llmactor.get_recompute_queue_size() > 0:
         oldest_item = llmactor.recompute_store.items[0].item
         oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining
+        oldest_item_input_len = oldest_item.input_size

         if any([
             llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
-            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
+            prefill_batch_size + oldest_item_input_len > MAX_NUM_BATCH_TOKENS,
             (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE
         ]):
             break
@@ -66,10 +69,11 @@ def fetch_prefill_items(llmactor, env, ):
     while llmactor.get_prefill_queue_size() > 0:
         oldest_item = llmactor.prefill_store.items[0]
         oldest_item_len = oldest_item.input_size + oldest_item.output_size - oldest_item.output_size_remaining
+        oldest_item_input_len = oldest_item.input_size

         if any([
             llmactor.get_decode_queue_size() + num_new_seq + 1 > MAX_NUM_SEQ,
-            prefill_batch_size + oldest_item_len > MAX_NUM_BATCH_TOKENS,
+            prefill_batch_size + oldest_item_input_len > MAX_NUM_BATCH_TOKENS,
             (prefill_batch_size + num_new_seq + llmactor.get_num_tokens_in_decode()) / (llmactor.max_num_tokens_allowed + 0.0) >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE
         ]):
             break
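
Note on the fix: in each of the four checks above, the batch-token condition now counts only the candidate item's input (prompt) tokens, where it previously also counted the tokens the item had already generated (input_size + output_size - output_size_remaining). The sketch below restates the corrected condition as a standalone function; the function name and constant values are illustrative assumptions, and only the three-part condition mirrors the diff.

# Minimal sketch of the corrected prefill-admission check (name and constants assumed).
MAX_NUM_SEQ = 256                           # assumption for illustration
MAX_NUM_BATCH_TOKENS = 4096                 # assumption for illustration
MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE = 0.9  # assumption for illustration

def can_admit_to_prefill(item_input_size, prefill_batch_size, num_new_seq,
                         decode_queue_size, num_tokens_in_decode, max_num_tokens_allowed):
    """Return True if the oldest queued item may join the current prefill batch."""
    return not any([
        # Sequence-count limit.
        decode_queue_size + num_new_seq + 1 > MAX_NUM_SEQ,
        # Fixed in this commit: only the item's input (prompt) tokens count
        # against the batch token budget.
        prefill_batch_size + item_input_size > MAX_NUM_BATCH_TOKENS,
        # Memory-pressure check, kept as in the diff.
        (prefill_batch_size + num_new_seq + num_tokens_in_decode) / max_num_tokens_allowed
        >= MAX_GPU_MEMORY_PERC_BEFORE_RECOMPUTE,
    ])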
