From f55e66eccff9e5583b6cde8e52e5a2f9d6af9fb2 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 13:54:17 +0300 Subject: [PATCH 1/7] Added invocation of PostResponse plugins Signed-off-by: Shmuel Kallner --- pkg/epp/handlers/server.go | 58 ++++++++++++++++++++++++--------- pkg/epp/scheduling/config.go | 6 +++- pkg/epp/scheduling/scheduler.go | 33 +++++++++++++++++++ 3 files changed, 81 insertions(+), 16 deletions(-) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 37d84027e..669547b2d 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -67,6 +67,7 @@ type StreamingServer struct { type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) + OnResponse(ctx context.Context, req *schedulingtypes.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. @@ -203,6 +204,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) case *extProcPb.ProcessingRequest_RequestTrailers: // This is currently unused. case *extProcPb.ProcessingRequest_ResponseHeaders: + responseHeaders := make(map[string]string) for _, header := range v.ResponseHeaders.Headers.GetHeaders() { value := string(header.RawValue) @@ -213,26 +215,52 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx.modelServerStreaming = true loggerTrace.Info("model server is streaming response") } + responseHeaders[header.Key] = value } - reqCtx.RequestState = ResponseRecieved - reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, + + llmReq := &schedulingtypes.LLMRequest{ + Model: reqCtx.Model, + Headers: responseHeaders, + ResolvedTargetModel: reqCtx.ResolvedTargetModel, + } + + var result *schedulingtypes.Result + result, err = s.scheduler.OnResponse(ctx, llmReq, reqCtx.TargetPod) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error handling response") + reqCtx.ResponseStatusCode = errutil.ModelServerError + } else { + headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + // This is for debugging purpose only. + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + } + + // Add headers added by PostResponse + for key, value := range result.MutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } + reqCtx.RequestState = ResponseRecieved + reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, }, }, }, }, - }, + } } case *extProcPb.ProcessingRequest_ResponseBody: diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index c02d9f56c..2f8ab178e 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -20,13 +20,15 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" // NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int, - picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule) *SchedulerConfig { + picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, + postResponsePlugins []plugins.PostResponse) *SchedulerConfig { return &SchedulerConfig{ preSchedulePlugins: preSchedulePlugins, filters: filters, scorers: scorers, picker: picker, postSchedulePlugins: postSchedulePlugins, + postResponsePlugins: postResponsePlugins, } } @@ -37,6 +39,7 @@ type SchedulerConfig struct { scorers map[plugins.Scorer]int // map from scorer to weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } var defPlugin = &defaultPlugin{} @@ -51,4 +54,5 @@ var defaultConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: defPlugin, postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 78a4f93de..2d3268a85 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -79,6 +79,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched scorers: config.scorers, picker: config.picker, postSchedulePlugins: config.postSchedulePlugins, + postResponsePlugins: config.postResponsePlugins, } } @@ -89,6 +90,7 @@ type Scheduler struct { scorers map[plugins.Scorer]int // map from scorer to its weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } type Datastore interface { @@ -208,6 +210,37 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty } } +// OnResponse is invoked during the processing of a response from an inference pod. It will invoke +// any defined plugins that process the response. +func (s *Scheduler) OnResponse(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. + pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll()) + var targetPod types.Pod + for _, pod := range pods { + if pod.GetPod().NamespacedName.String() == targetPodName { + targetPod = pod + break + } + } + + sCtx := types.NewSchedulingContext(ctx, req, pods) + + s.runPostResponsePlugins(sCtx, targetPod) + + return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil +} + +func (s *Scheduler) runPostResponsePlugins(ctx *types.SchedulingContext, targetPod types.Pod) { + for _, plugin := range s.postResponsePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostResponse(ctx, targetPod) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before)) + } +} + type defaultPlugin struct { picker.RandomPicker } From 91224cdd5a344872e8efd0edf7ea392e32d6ea14 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 13:54:42 +0300 Subject: [PATCH 2/7] Added test of invocation of PostResponse plugins Signed-off-by: Shmuel Kallner --- pkg/epp/scheduling/scheduler_test.go | 70 ++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index da2874c06..0dc0d7b87 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -468,6 +468,59 @@ func TestSchedulePlugins(t *testing.T) { } } +func TestPostResponse(t *testing.T) { + pr1 := &testPostResponse{ + NameRes: "pr1", + ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + ReceivedResponseHeaders: make(map[string]string), + } + + targetPod := k8stypes.NamespacedName{Name: "pod2"} + + tests := []struct { + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + responseHeaders map[string]string + wantMutatedHeaders map[string]string + }{ + { + name: "Simple postResponse test", + config: SchedulerConfig{ + postResponsePlugins: []plugins.PostResponse{pr1}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: targetPod}}, + }, + responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"}, + wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + }, + } + + for _, test := range tests { + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) + + req := &types.LLMRequest{ + Model: "test-model", + Headers: test.responseHeaders, + } + + result, err := scheduler.OnResponse(context.Background(), req, targetPod.String()) + if err != nil { + t.Errorf("Received an error. Error: %s", err) + } + + if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" { + t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff) + } + + if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" { + t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff) + } + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -548,6 +601,23 @@ func (tp *TestPlugin) reset() { tp.NumOfPickerCandidates = 0 } +type testPostResponse struct { + NameRes string + ReceivedResponseHeaders map[string]string + ExtraHeaders map[string]string +} + +func (pr *testPostResponse) Name() string { return pr.NameRes } + +func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { + for key, value := range ctx.Req.Headers { + pr.ReceivedResponseHeaders[key] = value + } + for key, value := range pr.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } +} + func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { res := []types.Pod{} for _, pod := range ctx.PodsSnapshot { From 3c1c7ce123b2843c42532e40c06f3139c0e9d4a9 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 15:39:13 +0300 Subject: [PATCH 3/7] Update pkg/epp/handlers/server.go Co-authored-by: Etai Lev Ran --- pkg/epp/handlers/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 669547b2d..8efe0c084 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -67,7 +67,7 @@ type StreamingServer struct { type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) - OnResponse(ctx context.Context, req *schedulingtypes.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error) + OnResponse(ctx context.Context, req *schedulingtypes.LLMRequest, targetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. From 3a2d92cb48ee8aed5cb4b6af64263638692e39e4 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 18:55:20 +0300 Subject: [PATCH 4/7] Added a strcut to contain response information Signed-off-by: Shmuel Kallner --- pkg/epp/scheduling/types/types.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index aaefcf5ee..1a0d81e87 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -45,6 +45,21 @@ func (r *LLMRequest) String() string { r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers) } +// LLMResponse contains information from the response received to be passed to plugins +type LLMResponse struct { + // Headers is a map of the response headers. Nil during body processing + Headers map[string]string + + // Body Is the body of the response or nil during header processing + Body string + + // IsStreaming indicates whether or not the response is being streamed by the model + IsSreaming bool + + // EndOfStream when true indicates that this invocation contains the last chunk of the response + EndOfStream bool +} + type Pod interface { GetPod() *backend.Pod GetMetrics() *backendmetrics.Metrics @@ -61,6 +76,7 @@ type SchedulingContext struct { context.Context Logger logr.Logger Req *LLMRequest + Resp *LLMResponse PodsSnapshot []Pod MutatedHeaders map[string]string } @@ -85,12 +101,13 @@ type PodMetrics struct { *backendmetrics.Metrics } -func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { +func NewSchedulingContext(ctx context.Context, req *LLMRequest, resp *LLMResponse, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ Context: ctx, Logger: logger, Req: req, + Resp: resp, PodsSnapshot: pods, MutatedHeaders: make(map[string]string), } From b372d7356a247464462274564c40b40fe0113791 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 18:56:55 +0300 Subject: [PATCH 5/7] Pass nil for LLMResponse during request processing Signed-off-by: Shmuel Kallner --- pkg/epp/scheduling/plugins/filter/filter_test.go | 6 +++--- pkg/epp/scheduling/plugins/scorer/kvcache_test.go | 2 +- pkg/epp/scheduling/plugins/scorer/queue_test.go | 2 +- pkg/epp/scheduling/scheduler.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/epp/scheduling/plugins/filter/filter_test.go b/pkg/epp/scheduling/plugins/filter/filter_test.go index 2354c3ef5..0fdae5659 100644 --- a/pkg/epp/scheduling/plugins/filter/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter/filter_test.go @@ -52,7 +52,7 @@ func TestFilter(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input) got := test.filter.Filter(ctx, test.input) if diff := cmp.Diff(test.output, got); diff != "" { @@ -187,7 +187,7 @@ func TestFilterFunc(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input) got := test.f(ctx, test.input) if diff := cmp.Diff(test.output, got); diff != "" { @@ -244,7 +244,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, } - ctx := types.NewSchedulingContext(context.Background(), req, pods) + ctx := types.NewSchedulingContext(context.Background(), req, nil, pods) // Run the filter function multiple times and count the results affinityCount := 0 diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache_test.go b/pkg/epp/scheduling/plugins/scorer/kvcache_test.go index 257a58c17..68be8a213 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache_test.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache_test.go @@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods) + ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods) scorer := &KVCacheScorer{} scores := scorer.Score(ctx, tt.pods) diff --git a/pkg/epp/scheduling/plugins/scorer/queue_test.go b/pkg/epp/scheduling/plugins/scorer/queue_test.go index 907681b25..d60eab66a 100644 --- a/pkg/epp/scheduling/plugins/scorer/queue_test.go +++ b/pkg/epp/scheduling/plugins/scorer/queue_test.go @@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods) + ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods) scores := scorer.Score(ctx, tt.pods) for i, pod := range tt.pods { diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 2d3268a85..332270855 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -110,7 +110,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) From ea5f2d7ed26b085fae06a15690fa5face56b871b Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 19:25:50 +0300 Subject: [PATCH 6/7] Pass LLMResponse to Scheduler.OnResponse, rather than LLMRequest Signed-off-by: Shmuel Kallner --- pkg/epp/handlers/server.go | 10 ++++------ pkg/epp/scheduling/scheduler.go | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 8efe0c084..d9ad6cfc0 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -67,7 +67,7 @@ type StreamingServer struct { type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) - OnResponse(ctx context.Context, req *schedulingtypes.LLMRequest, targetPodName string) (*schedulingtypes.Result, error) + OnResponse(ctx context.Context, resp *schedulingtypes.LLMResponse, targetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. @@ -218,14 +218,12 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) responseHeaders[header.Key] = value } - llmReq := &schedulingtypes.LLMRequest{ - Model: reqCtx.Model, - Headers: responseHeaders, - ResolvedTargetModel: reqCtx.ResolvedTargetModel, + llmResp := &schedulingtypes.LLMResponse{ + Headers: responseHeaders, } var result *schedulingtypes.Result - result, err = s.scheduler.OnResponse(ctx, llmReq, reqCtx.TargetPod) + result, err = s.scheduler.OnResponse(ctx, llmResp, reqCtx.TargetPod) if err != nil { logger.V(logutil.DEFAULT).Error(err, "Error handling response") reqCtx.ResponseStatusCode = errutil.ModelServerError diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 332270855..d254a1f69 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -212,7 +212,7 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty // OnResponse is invoked during the processing of a response from an inference pod. It will invoke // any defined plugins that process the response. -func (s *Scheduler) OnResponse(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { +func (s *Scheduler) OnResponse(ctx context.Context, resp *types.LLMResponse, targetPodName string) (*types.Result, error) { // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. @@ -225,7 +225,7 @@ func (s *Scheduler) OnResponse(ctx context.Context, req *types.LLMRequest, targe } } - sCtx := types.NewSchedulingContext(ctx, req, pods) + sCtx := types.NewSchedulingContext(ctx, nil, resp, pods) s.runPostResponsePlugins(sCtx, targetPod) From 3e42367251c5bc145a81bd946f39d18a24ede244 Mon Sep 17 00:00:00 2001 From: Shmuel Kallner Date: Wed, 7 May 2025 19:26:28 +0300 Subject: [PATCH 7/7] Update scheduler tests for newer APIs Signed-off-by: Shmuel Kallner --- pkg/epp/scheduling/scheduler_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 0dc0d7b87..854f26d05 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -501,12 +501,11 @@ func TestPostResponse(t *testing.T) { for _, test := range tests { scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) - req := &types.LLMRequest{ - Model: "test-model", + resp := &types.LLMResponse{ Headers: test.responseHeaders, } - result, err := scheduler.OnResponse(context.Background(), req, targetPod.String()) + result, err := scheduler.OnResponse(context.Background(), resp, targetPod.String()) if err != nil { t.Errorf("Received an error. Error: %s", err) } @@ -610,7 +609,7 @@ type testPostResponse struct { func (pr *testPostResponse) Name() string { return pr.NameRes } func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { - for key, value := range ctx.Req.Headers { + for key, value := range ctx.Resp.Headers { pr.ReceivedResponseHeaders[key] = value } for key, value := range pr.ExtraHeaders {