Skip to content

Commit 1fdb7b5

Browse files
committed
update constants
1 parent d793532 commit 1fdb7b5

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed
Binary file not shown.

simulations/llm_ig_simulation/loadbalancer.py

+10-11
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def estimate_avg_latency(self, llmactor, input_size, output_size, include_runnin
4949
items = llmactor.decode_store.items if include_running_requests else llmactor.decoded_store.items
5050

5151
for item in items:
52-
if include_running_requests or self.env.now - item.arrival_time > TTL:
52+
if self.env.now - item.arrival_time > TTL:
5353
continue # Skip long-running requests
5454

5555
tokens_in_kv_cache_at_start_of_decode = item.tokens_in_kv_cache_at_start_of_decode or 0
@@ -325,14 +325,14 @@ def find_target_pod(self, routing_type, input_size, output_size, target_latency
325325
return target_pod, latency_esimated
326326

327327

328-
def queueing_signal(self, use_overall = True) -> bool:
329-
if not use_overall:
328+
def queueing_signal(self, routing_type = "smart") -> bool:
329+
if routing_type == "smart":
330330
return self.check_saturations(use_pseudo_kv_cache=True, max_saturation= self.queueing_perc) or self.all_servers_queued()
331331
else :
332332
return self.get_overall_pending_tokens_perc() > self.queueing_perc or self.all_servers_queued()
333333

334-
def dequeueing_signal(self, use_overall = True) -> bool:
335-
if not use_overall:
334+
def dequeueing_signal(self, routing_type = "smart") -> bool:
335+
if routing_type == "smart":
336336
return self.check_saturations(use_pseudo_kv_cache=True, max_saturation= self.queueing_perc) == False and self.all_servers_queued() == False
337337
else :
338338
return self.get_overall_pending_tokens_perc() < self.queueing_perc and self.all_servers_queued() == False
@@ -357,11 +357,10 @@ def dequeue(self) -> Optional[Request]:
357357

358358
def dequeue_process(self, routing_type, drop_late_requests = False):
359359
while True:
360-
if not self.check_if_queues_empty() and self.dequeueing_signal():
360+
if not self.check_if_queues_empty() and self.dequeueing_signal(routing_type):
361361
# Get the request with the highest SLO violation
362362
req = self.dequeue()
363363
if req:
364-
#if self.env.now - req.arrival_time > req.target_latency:
365364
if (drop_late_requests == False) or (self.env.now - req.arrival_time < 100*req.target_latency): #ad-hoc
366365
target_pod, estimated_latency = self.find_target_pod(routing_type, req.input_size, req.output_size, req.target_latency, req.lora)
367366
req.target_pod = target_pod.id
@@ -467,7 +466,7 @@ def allPodsRunningCritical(self):
467466

468467
def generate_request_inference_gateway(
469468
self, rate, lora_requested, target_latency_list, prefix_latency_list,
470-
routing_type="random", prompt_output_tuple=None, mean_request_size=None,
469+
routing_type, prompt_output_tuple=None, mean_request_size=None,
471470
std_request_size=None, mean_output_size=None, std_output_size=None,
472471
estimated_output_size=None):
473472
"""
@@ -495,7 +494,7 @@ def generate_request_inference_gateway(
495494
cnt += 1
496495
self.messages_remaining_cnt -= 1
497496

498-
if self.should_enqueue_request():
497+
if self.should_enqueue_request(routing_type):
499498
self.enqueue_request(new_req, lora_requested, target_latency)
500499
else:
501500
self.route_request(new_req, routing_type, input_size, output_size, target_latency, lora_requested, estimated_output_size)
@@ -517,8 +516,8 @@ def create_request(self, request_id, input_size, output_size, target_latency):
517516
new_req.target_latency = target_latency
518517
return new_req
519518

520-
def should_enqueue_request(self):
521-
return self.queueing_signal() or not self.check_if_queues_empty()
519+
def should_enqueue_request(self, routing_type):
520+
return self.queueing_signal(routing_type) or not self.check_if_queues_empty()
522521

523522
def enqueue_request(self, new_req, lora_requested, target_latency):
524523
if lora_requested:

simulations/llm_ig_simulation/main.py

+8-4
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77

88
def main():
99
parser = argparse.ArgumentParser(description="Simulate LLM load balancing with configurable parameters.")
10-
parser.add_argument("--rates_lo", nargs='+', type=int, default=[30, ], help="List of low rates.")
11-
parser.add_argument("--rates_hi", nargs='+', type=int, default=[30,], help="List of high rates.")
10+
parser.add_argument("--rates_lo", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of low rates.")
11+
parser.add_argument("--rates_hi", nargs='+', type=int, default=[35, 30, 25, 20, 15, 10, 5, 1], help="List of high rates.")
1212
parser.add_argument("--no_of_messages", type=int, default=2500, help="Number of messages to simulate.")
1313
parser.add_argument("--mean_request_size_1", type=int, default=202, help="Mean request size for set 1.")
1414
parser.add_argument("--std_request_size_1", type=int, default=20, help="Standard deviation of request size for set 1.")
@@ -25,6 +25,9 @@ def main():
2525
parser.add_argument('--prefix-latency-lo', nargs='+', type=float, help='List of prefix of target latencies for low priority requests.')
2626
parser.add_argument('--prefix-latency-hi', nargs='+', type=float, help='List of prefix of target latencies for high priority requests.')
2727

28+
29+
parser.add_argument('--number-of-servers', type=int, default=6, help='List of target latencies for high priority requests.')
30+
2831
args = parser.parse_args()
2932

3033
# Use provided arguments or defaults
@@ -51,6 +54,8 @@ def main():
5154

5255
prefix_latency_list_lo = args.prefix_latency_lo if args.prefix_latency_lo else ['lo']
5356
prefix_latency_list_hi = args.prefix_latency_hi if args.prefix_latency_hi else ['hi']
57+
58+
number_of_servers = args.number_of_servers
5459

5560
# Define a structure to store results for all routing types
5661
results = {
@@ -110,7 +115,7 @@ def main():
110115
'tol_lat_time_lo': [], 'tol_lat_time_hi': []},
111116
}
112117

113-
all_routing_types = ["least", ]
118+
all_routing_types = ["least", "smart", "random" ]
114119
prompt_output_tuple = None
115120

116121
# Iterate over routing types
@@ -125,7 +130,6 @@ def main():
125130

126131
# Simpy environment and LLM actors setup
127132
env = simpy.Environment()
128-
number_of_servers =6
129133
list_of_llmactors = [LLMActor(env, 1, id) for id in range(number_of_servers)]
130134
lb = LoadBalancer(env, number_of_servers=number_of_servers, list_of_llmactors=list_of_llmactors, req_dict_prefill=req_dict_prefill, req_dict=req_dict, messages_remaining_cnt=no_of_messages*2)
131135
lb.queueing_perc = queueing_perc

0 commit comments

Comments
 (0)