
Commit 33f20a9

move to src folder, do weighted dequeuing
1 parent 9bc831c commit 33f20a9

8 files changed: +95 -34 lines changed


simulations/llm_ig_simulation/loadbalancer.py renamed to simulations/llm_ig_simulation/src/loadbalancer.py

Lines changed: 33 additions & 4 deletions
@@ -18,7 +18,7 @@ def __init__(self,
                  list_of_llmactors: List[LLMActor] = None,
                  messages_remaining_cnt: Optional[int] = None,
                  req_dict_prefill = {}, req_dict = {},
-                 queueing_perc: float = 0.2):
+                 queueing_perc: float = np.inf):
         self.number_of_servers = number_of_servers
         self.list_of_llmactors = list_of_llmactors or []
         assert len(self.list_of_llmactors) == number_of_servers, "Number of actors must match number of servers"
@@ -350,13 +350,13 @@ def find_target_pod(self, routing_type, input_size, output_size, target_latency

     def queueing_signal(self, routing_type = "smart") -> bool:
         if routing_type == "smart":
-            return self.check_saturations(use_pseudo_kv_cache=True, max_saturation= self.queueing_perc) or self.all_servers_queued()
+            return self.check_saturations(use_pseudo_kv_cache=False, max_saturation= self.queueing_perc) or self.all_servers_queued()
         else :
             return self.get_overall_pending_tokens_perc() > self.queueing_perc or self.all_servers_queued()

     def dequeueing_signal(self, routing_type = "smart") -> bool:
         if routing_type == "smart":
-            return self.check_saturations(use_pseudo_kv_cache=True, max_saturation= self.queueing_perc) == False and self.all_servers_queued() == False
+            return self.check_saturations(use_pseudo_kv_cache=False, max_saturation= self.queueing_perc) == False and self.all_servers_queued() == False
         else :
             return self.get_overall_pending_tokens_perc() < self.queueing_perc and self.all_servers_queued() == False

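Aside: the two signals above act as a hysteresis pair, assuming check_saturations reports whether any pod exceeds max_saturation. New work is parked while the pods look saturated (or every server already has queued work), and draining resumes only once neither condition holds. A minimal standalone sketch of that gating logic, with the saturation check stubbed out (names here are illustrative, not the simulator's API):

def should_queue(saturated: bool, all_servers_queued: bool) -> bool:
    # Hold new requests back while pods look saturated or every server already has queued work.
    return saturated or all_servers_queued

def should_dequeue(saturated: bool, all_servers_queued: bool) -> bool:
    # Resume draining only once neither back-pressure condition is active.
    return not saturated and not all_servers_queued

# With queueing_perc = np.inf (the new default), the saturation threshold is never
# exceeded, so queueing is effectively driven by the all-servers-queued condition alone.
assert should_queue(saturated=False, all_servers_queued=True)
assert should_dequeue(saturated=False, all_servers_queued=False)
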

@@ -369,6 +369,33 @@ def check_if_queues_empty(self) -> bool:
                 return False
         return True

+    import random
+
+    def weighted_dequeue(self) -> Optional[Request]:
+        # Get active targets and their latencies
+        active_targets = list(self.getActiveReqTargetLatencyInWindow(np.inf))
+
+        # Calculate inverse weights based on latencies
+        inverse_weights = {k: 1.0 / k for k in active_targets}
+
+        # Calculate total weight to normalize
+        total_weight = sum(inverse_weights.values())
+
+        # Calculate the relative probabilities for each target
+        target_probs = {k: inverse_weights[k] / total_weight for k in active_targets}
+
+        # Use random.choices to select a target based on probabilities
+        # and attempt to dequeue from the selected target's queue
+        for _ in range(100):  # Try up to 100 times
+            selected_target = random.choices(list(target_probs.keys()), weights=target_probs.values(), k=1)[0]
+
+            # Check if the selected target's queue is non-empty
+            if selected_target in self.queues and not self.queues[selected_target].empty():
+                req = self.queues[selected_target].get()
+                return req
+
+        return None
+
     def dequeue(self) -> Optional[Request]:
         active_targets = sorted(self.getActiveReqTargetLatencyInWindow(np.inf))
         for k in active_targets:
@@ -378,6 +405,8 @@ def dequeue(self) -> Optional[Request]:
         return None


+
+
     def dequeue_process(self, routing_type, drop_late_requests = False):
         while True:
             if not self.check_if_queues_empty() and self.dequeueing_signal(routing_type):
@@ -512,7 +541,7 @@ def generate_request_inference_gateway(
                 prompt_output_tuple, mean_request_size, std_request_size,
                 mean_output_size, std_output_size
             )
-            output_size = min(output_size, MAX_NUM_BATCH_TOKENS)
+            input_size = min(input_size, MAX_NUM_BATCH_TOKENS)

             request_id = f"{prefix}: {cnt}"
             new_req = self.create_request(request_id, input_size, output_size, target_latency)
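Aside: the new weighted_dequeue drains queues probabilistically, weighting each active target latency by its inverse so tighter-latency classes are served more often without starving the rest, in contrast to the strictly sorted order used by dequeue(). A self-contained sketch of that selection, using plain dicts of deques keyed by target latency (names here are illustrative, not the simulator's API):

import random
from collections import deque
from typing import Deque, Dict, Optional

def pick_by_inverse_latency(queues: Dict[float, Deque[str]], attempts: int = 100) -> Optional[str]:
    """Dequeue from a queue chosen with probability proportional to 1 / target_latency."""
    # Only consider latency classes that currently have queued requests.
    active = [lat for lat, q in queues.items() if q]
    if not active:
        return None

    # Inverse-latency weights: a 0.5 s class gets four times the weight of a 2.0 s class.
    weights = [1.0 / lat for lat in active]

    # random.choices normalizes the weights itself, so no explicit division is needed.
    for _ in range(attempts):
        target = random.choices(active, weights=weights, k=1)[0]
        if queues[target]:  # still non-empty (it can drain between attempts)
            return queues[target].popleft()
    return None

# Example: with a 0.5 s and a 2.0 s class both non-empty, the 0.5 s queue is drained ~4x as often.
queues = {0.5: deque(["hi-0", "hi-1"]), 2.0: deque(["lo-0", "lo-1"])}
print(pick_by_inverse_latency(queues))
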

simulations/llm_ig_simulation/main.py renamed to simulations/llm_ig_simulation/src/main.py

Lines changed: 62 additions & 30 deletions
@@ -2,15 +2,20 @@
 from collections import Counter
 import csv
 from datetime import datetime
+import os
+import string
 import numpy as np
 import simpy
 from llmactor import LLMActor
 from loadbalancer import LoadBalancer
+import sys

 def main():
+
+
     parser = argparse.ArgumentParser(description="Simulate LLM load balancing with configurable parameters.")
-    parser.add_argument("--rates_lo", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of low rates.")
-    parser.add_argument("--rates_hi", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of high rates.")
+    parser.add_argument("--rates_lo", nargs='+', type=int, default=[40, 35, 30, 25, 20, 15, 10, 5, 1], help="List of low rates.")
+    parser.add_argument("--rates_hi", nargs='+', type=int, default=[40, 35, 30, 25, 20, 15, 10, 5, 1], help="List of high rates.")
     parser.add_argument("--no_of_messages", type=int, default=2500, help="Number of messages to simulate.")
     parser.add_argument("--mean_request_size_1", type=int, default=202, help="Mean request size for set 1.")
     parser.add_argument("--std_request_size_1", type=int, default=20, help="Standard deviation of request size for set 1.")
@@ -20,7 +25,8 @@ def main():
     parser.add_argument("--std_request_size_2", type=int, default=20, help="Standard deviation of request size for set 2.")
     parser.add_argument("--mean_output_size_2", type=int, default=179, help="Mean output size for set 2.")
     parser.add_argument("--std_output_size_2", type=int, default=17, help="Standard deviation of output size for set 2.")
-    parser.add_argument("--queueing_perc", type=float, default=0.19, help="Queueing percentage.")
+    parser.add_argument("--estimated_output_size", type=str, default="mean", help="how to determine the mean output size.")
+    parser.add_argument("--queueing_perc", type=float, default=np.inf, help="Queueing percentage.")
     parser.add_argument('--target-latency-lo', nargs='+', type=float, help='List of target latencies for low priority requests.')
     parser.add_argument('--target-latency-hi', nargs='+', type=float, help='List of target latencies for high priority requests.')

@@ -29,6 +35,8 @@ def main():


     parser.add_argument('--number-of-servers', type=int, default=6, help='List of target latencies for high priority requests.')
+    parser.add_argument('--output-file', type=str, default="result.csv", help='output file name.')
+    parser.add_argument('--routing-type', type=str, default="random", help='routing type')

     args = parser.parse_args()

@@ -58,6 +66,9 @@ def main():
     prefix_latency_list_hi = args.prefix_latency_hi if args.prefix_latency_hi else ['hi']

     number_of_servers = args.number_of_servers
+    output_file = args.output_file
+    routing_type = args.routing_type
+    estimated_output_size = args.estimated_output_size

     # Define a structure to store results for all routing types
     results = {
@@ -71,9 +82,9 @@ def main():
                  'recompute_cnt' : [], 'recompute_cnt_hi' : [], 'recompute_cnt_lo' : [],
                  'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [], 'queue_time_lo': [], 'queue_time_hi': [],
                  'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
-                 'avg_prefill_queue_size' = [],
-                 'avg_pending_tokens_perc' = [],
-                 'avg_actual_tokens_perc' = []},
+                 'avg_prefill_queue_size' : [],
+                 'avg_pending_tokens_perc' : [],
+                 'avg_actual_tokens_perc' : [], 'request_count': []},

     'smart': {'latency': [], 'latency_lo': [], 'latency_hi': [],
               'estimated_latency': [], 'estimated_latency_lo': [], 'estimated_latency_hi': [],
@@ -87,9 +98,9 @@ def main():
               'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [],
               'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [], 'queue_time_lo': [], 'queue_time_hi': [],
               'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
-              'avg_prefill_queue_size' = [],
-              'avg_pending_tokens_perc' = [],
-              'avg_actual_tokens_perc' = []},
+              'avg_prefill_queue_size' : [],
+              'avg_pending_tokens_perc' : [],
+              'avg_actual_tokens_perc' : [], 'request_count': []},

     'leastlatency': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                      'throughput_prefill': [], 'throughput_decode': [],
@@ -101,9 +112,9 @@ def main():
                      'recompute_cnt' : [], 'recompute_cnt_hi' : [], 'recompute_cnt_lo' : [],
                      'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [], 'queue_time_lo': [], 'queue_time_hi': [],
                      'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
-                     'avg_prefill_queue_size' = [],
-                     'avg_pending_tokens_perc' = [],
-                     'avg_actual_tokens_perc' = []},
+                     'avg_prefill_queue_size' : [],
+                     'avg_pending_tokens_perc' : [],
+                     'avg_actual_tokens_perc' : [], 'request_count': []},
     'least': {'latency': [], 'latency_lo': [], 'latency_hi': [],
               'throughput_prefill': [], 'throughput_decode': [],
               'throughput_prefill_lo': [], 'throughput_decode_lo': [],
@@ -114,9 +125,9 @@ def main():
               'recompute_cnt' : [], 'recompute_cnt_hi' : [], 'recompute_cnt_lo' : [],
               'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [], 'queue_time_lo': [], 'queue_time_hi': [],
               'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
-              'avg_prefill_queue_size' = [],
-              'avg_pending_tokens_perc' = [],
-              'avg_actual_tokens_perc' = []},
+              'avg_prefill_queue_size' : [],
+              'avg_pending_tokens_perc' : [],
+              'avg_actual_tokens_perc' : [], 'request_count': []},
     'random': {'latency': [], 'latency_lo': [], 'latency_hi': [],
                'throughput_prefill': [], 'throughput_decode': [],
                'throughput_prefill_lo': [], 'throughput_decode_lo': [],
@@ -127,12 +138,12 @@ def main():
                'recompute_cnt' : [], 'recompute_cnt_hi' : [], 'recompute_cnt_lo' : [],
                'pct_below_latency_target_lo': [], 'pct_below_latency_target_hi': [], 'queue_time_lo': [], 'queue_time_hi': [],
                'tol_lat_time_lo': [], 'tol_lat_time_hi': [],
-               'avg_prefill_queue_size' = [],
-               'avg_pending_tokens_perc' = [],
-               'avg_actual_tokens_perc' = []},
+               'avg_prefill_queue_size' : [],
+               'avg_pending_tokens_perc' : [],
+               'avg_actual_tokens_perc' : [], 'request_count': []},
     }

-    all_routing_types = [ "random", ]
+    all_routing_types = [ routing_type ]
     prompt_output_tuple = None

     # Iterate over routing types
@@ -144,15 +155,23 @@ def main():
         req_dict_prefill = {}
         SIM_DURATION = SIM_DURATIONS[i]
         print(f'Simulate with rate: for lo {rates_lo[i]} and for hi {rates_hi[i]} and routing type: {routing_type}')
-
+        sys.stdout.flush()
         # Simpy environment and LLM actors setup
         env = simpy.Environment()
         list_of_llmactors = [LLMActor(env, 1, id) for id in range(number_of_servers)]
         lb = LoadBalancer(env, number_of_servers=number_of_servers, list_of_llmactors=list_of_llmactors, req_dict_prefill=req_dict_prefill, req_dict=req_dict, messages_remaining_cnt=no_of_messages*2)
         lb.queueing_perc = queueing_perc

-        estimated_output_size = mean_output_size_1
-        lb.process(rates_lo[i], lora_requested_lo, target_latency_list_lo, prefix_latency_list_lo, routing_type, prompt_output_tuple, mean_request_size_1, std_request_size_1, mean_output_size_1, std_output_size_1, estimated_output_size)
+        if estimated_output_size == "mean":
+            estimated_output_size_1 = mean_output_size_1
+            estimated_output_size_2 = mean_output_size_2
+        elif estimated_output_size == "p95":
+            estimated_output_size_1 = mean_output_size_1 + 2 * std_output_size_1
+            estimated_output_size_2 = mean_output_size_2 + 2 * std_output_size_2
+
+
+        lb.process(rates_lo[i], lora_requested_lo, target_latency_list_lo, prefix_latency_list_lo, routing_type, prompt_output_tuple, mean_request_size_1, std_request_size_1, mean_output_size_1, std_output_size_1, estimated_output_size_1)
+        lb.process(rates_hi[i], lora_requested_hi, target_latency_list_hi, prefix_latency_list_hi, routing_type, prompt_output_tuple, mean_request_size_1, std_request_size_1, mean_output_size_1, std_output_size_1, estimated_output_size_2)
         env.run(until=SIM_DURATION)

         # Track which pod processed each request (lo and hi)
@@ -268,11 +287,14 @@ def main():
         results[routing_type]['avg_prefill_queue_size'].append(np.mean(prefill_queue_size))
         results[routing_type]['avg_pending_tokens_perc'].append(np.mean(pending_tokens_at_arrival_perc))
         results[routing_type]['avg_actual_tokens_perc'].append(np.mean(actual_tokens_at_arrival_perc))
+

-        l1 = [np.sum(list(dict(x).values())) for x in results[routing_type]['target_pods_lo']]
-        l2 = [np.sum(list(dict(x).values())) for x in results[routing_type]['target_pods_hi']]
+        l1 = [np.sum(list(dict(x).values())) for x in results[routing_type]['target_pods_lo']][-1]
+        l2 = [np.sum(list(dict(x).values())) for x in results[routing_type]['target_pods_hi']][-1]

-        print(f'req count {[(l1[i], l2[i]) for i in range(len(l1))]}')
+        print(f'req count {(l1, l2)}')
+        sys.stdout.flush()
+        results[routing_type]['request_count'].append(len(completed_req))

         if routing_type == 'smart':
             results[routing_type]['estimated_latency'].append(estimated_latency_cur)
@@ -288,18 +310,29 @@ def main():
         print(f'QPS: {rates_lo[i]} (lo), {rates_hi[i]} (hi)')
         print(f'% of lo requests below target: {pct_below_target_lo}%')
        print(f'% of hi requests below target: {pct_below_target_hi}%')
+        print(f"prefill_queue_size {np.mean(prefill_queue_size)}")
+        print(f"pending_tokens_perc {np.mean(pending_tokens_at_arrival_perc)}")
+        print(f"actual_tokens_perc {np.mean(actual_tokens_at_arrival_perc)}")
+        sys.stdout.flush()
+
+

     # Create a timestamp
     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

     # Create the output file name with the timestamp
-    output_file = f"results_{timestamp}.json"
+



     # Write results to CSV
+    # Ensure the output directory exists
+    output_dir = os.path.dirname(output_file)
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
     with open(output_file, 'w', newline='') as csvfile:
-        fieldnames = ['RoutingType', 'RateIndex', 'Latency', 'Latency_Lo', 'Latency_Hi','Estimated_Latency', 'Estimated_Latency_lo', 'Estimated_Latency_hi', 'avg_prefill_queue_size', 'avg_pending_tokens_perc', 'avg_actual_tokens_perc' ]
+        fieldnames = ['RoutingType', 'RateIndex', 'Latency', 'Latency_Lo', 'Latency_Hi', 'avg_prefill_queue_size', 'avg_pending_tokens_perc', 'avg_actual_tokens_perc', 'pct_below_latency_target_lo', 'pct_below_latency_target_hi']
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

         writer.writeheader()
@@ -313,12 +346,11 @@ def main():
                 'Latency': results[routing_type]['latency'][i],
                 'Latency_Lo': results[routing_type]['latency_lo'][i],
                 'Latency_Hi': results[routing_type]['latency_hi'][i],
-                'Estimated_Latency': results[routing_type]['estimated_latency'][i],
-                'Estimated_Latency_Lo': results[routing_type]['estimated_latency_lo'][i],
-                'Estimated_Latency_Hi': results[routing_type]['estimated_latency_hi'][i],
                 'avg_prefill_queue_size': results[routing_type]['avg_prefill_queue_size'][i],
                 'avg_pending_tokens_perc': results[routing_type]['avg_pending_tokens_perc'][i],
                 'avg_actual_tokens_perc': results[routing_type]['avg_actual_tokens_perc'][i],
+                'pct_below_latency_target_lo': results[routing_type]['pct_below_latency_target_lo'][i],
+                'pct_below_latency_target_hi': results[routing_type]['pct_below_latency_target_hi'][i],
             })

     print(f"Results have been saved to {output_file}")
