Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 220915f

Browse files
authored
Merge pull request #113 from shmuelk/post-response
feat: Add the invocation of the Post response plugins
2 parents e0eee4c + 6fffe9e commit 220915f

File tree

6 files changed

+165
-26
lines changed

6 files changed

+165
-26
lines changed

pkg/epp/handlers/server.go

+44-15
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3838
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
3939
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
40+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4041
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4142
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
4243
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -66,6 +67,7 @@ type StreamingServer struct {
6667

6768
type Scheduler interface {
6869
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
70+
RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error)
6971
}
7072

7173
// RequestContext stores context information during the life time of an HTTP request.
@@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
189191
case *extProcPb.ProcessingRequest_RequestTrailers:
190192
// This is currently unused.
191193
case *extProcPb.ProcessingRequest_ResponseHeaders:
194+
responseHeaders := make(map[string]string)
192195
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
193196
value := string(header.RawValue)
194197

@@ -199,27 +202,53 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
199202
reqCtx.modelServerStreaming = true
200203
loggerTrace.Info("model server is streaming response")
201204
}
205+
responseHeaders[header.Key] = value
202206
}
203207

204-
reqCtx.RequestState = ResponseRecieved
205-
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
206-
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
207-
ResponseHeaders: &extProcPb.HeadersResponse{
208-
Response: &extProcPb.CommonResponse{
209-
HeaderMutation: &extProcPb.HeaderMutation{
210-
SetHeaders: []*configPb.HeaderValueOption{
211-
{
212-
Header: &configPb.HeaderValue{
213-
// This is for debugging purpose only.
214-
Key: "x-went-into-resp-headers",
215-
RawValue: []byte("true"),
216-
},
217-
},
208+
llmReq := &schedulingtypes.LLMRequest{
209+
Model: reqCtx.Model,
210+
Headers: responseHeaders,
211+
ResolvedTargetModel: reqCtx.ResolvedTargetModel,
212+
}
213+
214+
var result *types.Result
215+
result, err = s.scheduler.RunPostResponsePlugins(ctx, llmReq, reqCtx.TargetPod)
216+
if err != nil {
217+
logger.V(logutil.DEFAULT).Error(err, "Error handling response")
218+
reqCtx.ResponseStatusCode = errutil.ModelServerError
219+
} else {
220+
headers := []*configPb.HeaderValueOption{
221+
{
222+
Header: &configPb.HeaderValue{
223+
// This is for debugging purpose only.
224+
Key: "x-went-into-resp-headers",
225+
RawValue: []byte("true"),
226+
},
227+
},
228+
}
229+
230+
// Add headers added by PostResponse
231+
for key, value := range result.MutatedHeaders {
232+
headers = append(headers, &configPb.HeaderValueOption{
233+
Header: &configPb.HeaderValue{
234+
Key: key,
235+
RawValue: []byte(value),
236+
},
237+
})
238+
}
239+
240+
reqCtx.RequestState = ResponseRecieved
241+
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
242+
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
243+
ResponseHeaders: &extProcPb.HeadersResponse{
244+
Response: &extProcPb.CommonResponse{
245+
HeaderMutation: &extProcPb.HeaderMutation{
246+
SetHeaders: headers,
218247
},
219248
},
220249
},
221250
},
222-
},
251+
}
223252
}
224253

225254
case *extProcPb.ProcessingRequest_ResponseBody:

pkg/epp/scheduling/config.go

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type SchedulerConfig struct {
2626
scorers map[plugins.Scorer]int // map from scorer to weight
2727
picker plugins.Picker
2828
postSchedulePlugins []plugins.PostSchedule
29+
postResponsePlugins []plugins.PostResponse
2930
}
3031

3132
var defPlugin = &defaultPlugin{}
@@ -40,4 +41,5 @@ var defaultConfig = &SchedulerConfig{
4041
scorers: map[plugins.Scorer]int{},
4142
picker: defPlugin,
4243
postSchedulePlugins: []plugins.PostSchedule{},
44+
postResponsePlugins: []plugins.PostResponse{},
4345
}

pkg/epp/scheduling/local_config.go

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ const (
3636
loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT"
3737
)
3838

39+
func init() {
40+
setDefaultConfig()
41+
}
42+
3943
func setDefaultConfig() {
4044
// since the default config is a global variable, we add this function to minimize rebase conflicts.
4145
// this configuration is a temporary state, it should be better streamlined.

pkg/epp/scheduling/scheduler.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ var (
6969
)
7070

7171
func NewScheduler(datastore Datastore) *Scheduler {
72-
setDefaultConfig()
7372
return NewSchedulerWithConfig(datastore, defaultConfig)
7473
}
7574

@@ -81,6 +80,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched
8180
scorers: config.scorers,
8281
picker: config.picker,
8382
postSchedulePlugins: config.postSchedulePlugins,
83+
postResponsePlugins: config.postResponsePlugins,
8484
}
8585
}
8686

@@ -91,6 +91,7 @@ type Scheduler struct {
9191
scorers map[plugins.Scorer]int // map from scorer to its weight
9292
picker plugins.Picker
9393
postSchedulePlugins []plugins.PostSchedule
94+
postResponsePlugins []plugins.PostResponse
9495
}
9596

9697
type Datastore interface {
@@ -211,6 +212,38 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty
211212
}
212213
}
213214

215+
func (s *Scheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) {
216+
logger := log.FromContext(ctx)
217+
218+
pool, err := s.datastore.PoolGet()
219+
if err != nil {
220+
return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods
221+
}
222+
223+
// Snapshot pod metrics from the datastore to:
224+
// 1. Reduce concurrent access to the datastore.
225+
// 2. Ensure consistent data during the scheduling operation of a request.
226+
pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll())
227+
var targetPod types.Pod
228+
for _, pod := range pods {
229+
if pod.GetPod().NamespacedName.String() == targetPodName {
230+
targetPod = pod
231+
break
232+
}
233+
}
234+
235+
sCtx := types.NewSchedulingContext(ctx, req, pods, pool.Spec.TargetPortNumber)
236+
237+
for _, plugin := range s.postResponsePlugins {
238+
logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name())
239+
before := time.Now()
240+
plugin.PostResponse(sCtx, targetPod)
241+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before))
242+
}
243+
244+
return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil
245+
}
246+
214247
type defaultPlugin struct {
215248
picker.RandomPicker
216249
}

pkg/epp/scheduling/scheduler_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,56 @@ func TestSchedulePlugins(t *testing.T) {
483483
}
484484
}
485485

486+
func TestPostResponse(t *testing.T) {
487+
pr1 := &testPostResponse{
488+
NameRes: "pr1",
489+
ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"},
490+
ReceivedResponseHeaders: make(map[string]string),
491+
}
492+
493+
tests := []struct {
494+
name string
495+
config SchedulerConfig
496+
input []*backendmetrics.FakePodMetrics
497+
responseHeaders map[string]string
498+
wantMutatedHeaders map[string]string
499+
}{
500+
{
501+
name: "Simple postResponse test",
502+
config: SchedulerConfig{
503+
postResponsePlugins: []plugins.PostResponse{pr1},
504+
},
505+
input: []*backendmetrics.FakePodMetrics{
506+
{Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
507+
},
508+
responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"},
509+
wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"},
510+
},
511+
}
512+
513+
for _, test := range tests {
514+
scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)
515+
516+
req := &types.LLMRequest{
517+
Model: "test-model",
518+
Headers: test.responseHeaders,
519+
}
520+
521+
result, err := scheduler.RunPostResponsePlugins(context.Background(), req, test.input[0].Pod.NamespacedName.String())
522+
if err != nil {
523+
t.Errorf("Received an error. Error: %s", err)
524+
}
525+
526+
if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" {
527+
t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff)
528+
}
529+
530+
if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" {
531+
t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff)
532+
}
533+
}
534+
}
535+
486536
type fakeDataStore struct {
487537
pods []*backendmetrics.FakePodMetrics
488538
}
@@ -571,6 +621,23 @@ func (tp *TestPlugin) reset() {
571621
tp.NumOfPickerCandidates = 0
572622
}
573623

624+
type testPostResponse struct {
625+
NameRes string
626+
ReceivedResponseHeaders map[string]string
627+
ExtraHeaders map[string]string
628+
}
629+
630+
func (pr *testPostResponse) Name() string { return pr.NameRes }
631+
632+
func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {
633+
for key, value := range ctx.Req.Headers {
634+
pr.ReceivedResponseHeaders[key] = value
635+
}
636+
for key, value := range pr.ExtraHeaders {
637+
ctx.MutatedHeaders[key] = value
638+
}
639+
}
640+
574641
func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod {
575642
res := []types.Pod{}
576643
for _, pod := range ctx.PodsSnapshot {

pkg/epp/scheduling/scorers_test.go

+14-10
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,23 @@ func TestScorers(t *testing.T) {
8686
},
8787
},
8888
wantRes: &types.Result{
89-
TargetPod: &types.PodMetrics{
90-
Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
91-
Metrics: &backendmetrics.Metrics{
92-
WaitingQueueSize: 0,
93-
KVCacheUsagePercent: 0.2,
94-
MaxActiveModels: 2,
95-
ActiveModels: map[string]int{
96-
"foo": 1,
97-
"bar": 1,
89+
TargetPod: &types.ScoredPod{
90+
Pod: &types.PodMetrics{
91+
Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
92+
Metrics: &backendmetrics.Metrics{
93+
WaitingQueueSize: 0,
94+
KVCacheUsagePercent: 0.2,
95+
MaxActiveModels: 2,
96+
ActiveModels: map[string]int{
97+
"foo": 1,
98+
"bar": 1,
99+
},
100+
WaitingModels: map[string]int{},
98101
},
99-
WaitingModels: map[string]int{},
100102
},
103+
Score: 0.5,
101104
},
105+
MutatedHeaders: map[string]string{},
102106
},
103107
},
104108
}

0 commit comments

Comments
 (0)