@@ -483,6 +483,56 @@ func TestSchedulePlugins(t *testing.T) {
483
483
}
484
484
}
485
485
486
+ func TestPostResponse (t * testing.T ) {
487
+ pr1 := & testPostResponse {
488
+ NameRes : "pr1" ,
489
+ ExtraHeaders : map [string ]string {"x-session-id" : "qwer-asdf-zxcv" },
490
+ ReceivedResponseHeaders : make (map [string ]string ),
491
+ }
492
+
493
+ tests := []struct {
494
+ name string
495
+ config SchedulerConfig
496
+ input []* backendmetrics.FakePodMetrics
497
+ responseHeaders map [string ]string
498
+ wantMutatedHeaders map [string ]string
499
+ }{
500
+ {
501
+ name : "Simple postResponse test" ,
502
+ config : SchedulerConfig {
503
+ postResponsePlugins : []plugins.PostResponse {pr1 },
504
+ },
505
+ input : []* backendmetrics.FakePodMetrics {
506
+ {Pod : & backendmetrics.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod1" }}},
507
+ },
508
+ responseHeaders : map [string ]string {"Content-type" : "application/json" , "Content-Length" : "1234" },
509
+ wantMutatedHeaders : map [string ]string {"x-session-id" : "qwer-asdf-zxcv" },
510
+ },
511
+ }
512
+
513
+ for _ , test := range tests {
514
+ scheduler := NewSchedulerWithConfig (& fakeDataStore {pods : test .input }, & test .config )
515
+
516
+ req := & types.LLMRequest {
517
+ Model : "test-model" ,
518
+ Headers : test .responseHeaders ,
519
+ }
520
+
521
+ result , err := scheduler .RunPostResponsePlugins (context .Background (), req , test .input [0 ].Pod .NamespacedName .String ())
522
+ if err != nil {
523
+ t .Errorf ("Received an error. Error: %s" , err )
524
+ }
525
+
526
+ if diff := cmp .Diff (test .responseHeaders , pr1 .ReceivedResponseHeaders ); diff != "" {
527
+ t .Errorf ("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v" , diff )
528
+ }
529
+
530
+ if diff := cmp .Diff (test .wantMutatedHeaders , result .MutatedHeaders ); diff != "" {
531
+ t .Errorf ("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v" , diff )
532
+ }
533
+ }
534
+ }
535
+
486
536
type fakeDataStore struct {
487
537
pods []* backendmetrics.FakePodMetrics
488
538
}
@@ -571,6 +621,23 @@ func (tp *TestPlugin) reset() {
571
621
tp .NumOfPickerCandidates = 0
572
622
}
573
623
624
+ type testPostResponse struct {
625
+ NameRes string
626
+ ReceivedResponseHeaders map [string ]string
627
+ ExtraHeaders map [string ]string
628
+ }
629
+
630
+ func (pr * testPostResponse ) Name () string { return pr .NameRes }
631
+
632
+ func (pr * testPostResponse ) PostResponse (ctx * types.SchedulingContext , pod types.Pod ) {
633
+ for key , value := range ctx .Req .Headers {
634
+ pr .ReceivedResponseHeaders [key ] = value
635
+ }
636
+ for key , value := range pr .ExtraHeaders {
637
+ ctx .MutatedHeaders [key ] = value
638
+ }
639
+ }
640
+
574
641
func findPods (ctx * types.SchedulingContext , names ... k8stypes.NamespacedName ) []types.Pod {
575
642
res := []types.Pod {}
576
643
for _ , pod := range ctx .PodsSnapshot {
0 commit comments