@@ -49,7 +49,7 @@ def estimate_avg_latency(self, llmactor, input_size, output_size, include_runnin
        items = llmactor.decode_store.items if include_running_requests else llmactor.decoded_store.items

        for item in items:
-            if include_running_requests or self.env.now - item.arrival_time > TTL:
+            if self.env.now - item.arrival_time > TTL:
                continue  # Skip long-running requests

            tokens_in_kv_cache_at_start_of_decode = item.tokens_in_kv_cache_at_start_of_decode or 0
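A note on the fix above: the old condition short-circuited on `include_running_requests`, so whenever that flag was true every item was skipped and the latency estimate saw no data; the new condition applies the TTL cutoff uniformly. A minimal sketch of the intended filter, where the `TTL` value and the item shape are assumptions for illustration:

```python
# Minimal sketch of the fixed TTL filter; TTL's value and the item shape
# are illustrative assumptions, not the simulator's actual config.
from types import SimpleNamespace

TTL = 30.0  # assumed staleness cutoff, in simulated seconds

def fresh_items(items, now):
    # Mirrors the new branch: skip only items older than TTL.
    return [it for it in items if now - it.arrival_time <= TTL]

items = [SimpleNamespace(arrival_time=t) for t in (0.0, 5.0, 28.0)]
print(len(fresh_items(items, now=31.0)))  # 2 -- the t=0.0 item is skipped
```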
@@ -325,14 +325,14 @@ def find_target_pod(self, routing_type, input_size, output_size, target_latency
        return target_pod, latency_esimated


-    def queueing_signal(self, use_overall=True) -> bool:
-        if not use_overall:
+    def queueing_signal(self, routing_type="smart") -> bool:
+        if routing_type == "smart":
            return self.check_saturations(use_pseudo_kv_cache=True, max_saturation=self.queueing_perc) or self.all_servers_queued()
        else:
            return self.get_overall_pending_tokens_perc() > self.queueing_perc or self.all_servers_queued()

-    def dequeueing_signal(self, use_overall=True) -> bool:
-        if not use_overall:
+    def dequeueing_signal(self, routing_type="smart") -> bool:
+        if routing_type == "smart":
            return self.check_saturations(use_pseudo_kv_cache=True, max_saturation=self.queueing_perc) == False and self.all_servers_queued() == False
        else:
            return self.get_overall_pending_tokens_perc() < self.queueing_perc and self.all_servers_queued() == False
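In the smart branch these two signals are exact logical complements (De Morgan over the same pair of booleans), so under smart routing the gateway starts draining precisely when it stops queueing; the overall branch leaves a neutral point only when the pending-token percentage exactly equals the threshold. A quick check of the complement property, using stand-in booleans rather than simulator state:

```python
# Stand-in booleans verifying the smart-branch signals are complements;
# purely illustrative, no simulator state involved.
from itertools import product

for saturated, all_queued in product([False, True], repeat=2):
    queueing = saturated or all_queued                  # queueing_signal
    dequeueing = (not saturated) and (not all_queued)   # dequeueing_signal
    assert dequeueing == (not queueing)
print("smart-branch signals are exact complements")
```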
@@ -357,11 +357,10 @@ def dequeue(self) -> Optional[Request]:

    def dequeue_process(self, routing_type, drop_late_requests=False):
        while True:
-            if not self.check_if_queues_empty() and self.dequeueing_signal():
+            if not self.check_if_queues_empty() and self.dequeueing_signal(routing_type):
                # Get the request with the highest SLO violation
                req = self.dequeue()
                if req:
-                    #if self.env.now - req.arrival_time > req.target_latency:
                    if (drop_late_requests == False) or (self.env.now - req.arrival_time < 100 * req.target_latency): #ad-hoc
                        target_pod, estimated_latency = self.find_target_pod(routing_type, req.input_size, req.output_size, req.target_latency, req.lora)
                        req.target_pod = target_pod.id
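The ad-hoc guard kept above only drops a request once it has aged past 100x its latency target. A self-contained sketch of that predicate, with the 100x cutoff taken from the diff rather than tuned:

```python
# Sketch of the late-drop guard in dequeue_process; the 100x multiplier is
# the ad-hoc constant from the diff, not a calibrated value.
def keep_request(now, arrival_time, target_latency, drop_late_requests=False):
    if not drop_late_requests:
        return True  # never drop when the flag is off
    return now - arrival_time < 100 * target_latency

print(keep_request(5.0, 0.0, 0.1, drop_late_requests=True))   # True: age 5 < 10
print(keep_request(20.0, 0.0, 0.1, drop_late_requests=True))  # False: age 20 >= 10
```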
@@ -467,7 +466,7 @@ def allPodsRunningCritical(self):

    def generate_request_inference_gateway(
        self, rate, lora_requested, target_latency_list, prefix_latency_list,
-        routing_type="random", prompt_output_tuple=None, mean_request_size=None,
+        routing_type, prompt_output_tuple=None, mean_request_size=None,
        std_request_size=None, mean_output_size=None, std_output_size=None,
        estimated_output_size=None):
        """
@@ -495,7 +494,7 @@ def generate_request_inference_gateway(
            cnt += 1
            self.messages_remaining_cnt -= 1

-            if self.should_enqueue_request():
+            if self.should_enqueue_request(routing_type):
                self.enqueue_request(new_req, lora_requested, target_latency)
            else:
                self.route_request(new_req, routing_type, input_size, output_size, target_latency, lora_requested, estimated_output_size)
@@ -517,8 +516,8 @@ def create_request(self, request_id, input_size, output_size, target_latency):
        new_req.target_latency = target_latency
        return new_req

-    def should_enqueue_request(self):
-        return self.queueing_signal() or not self.check_if_queues_empty()
+    def should_enqueue_request(self, routing_type):
+        return self.queueing_signal(routing_type) or not self.check_if_queues_empty()

    def enqueue_request(self, new_req, lora_requested, target_latency):
        if lora_requested:
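Taken together, these hunks thread `routing_type` from the request generator through `should_enqueue_request` into `queueing_signal`, replacing the old `use_overall` flag. A runnable stub of the new call chain, where the method bodies are stand-ins rather than the simulator's logic:

```python
# Stub of the new call chain; the bodies below are stand-ins, not the
# simulator's actual saturation and queue checks.
class Router:
    def queueing_signal(self, routing_type="smart") -> bool:
        return routing_type == "smart"   # stubbed saturation decision

    def check_if_queues_empty(self) -> bool:
        return True                      # stubbed: nothing queued yet

    def should_enqueue_request(self, routing_type) -> bool:
        return self.queueing_signal(routing_type) or not self.check_if_queues_empty()

router = Router()
print(router.should_enqueue_request("smart"))    # True: smart branch signals queueing
print(router.should_enqueue_request("random"))   # False: fallback branch does not
```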
0 commit comments