Skip to content

Commit cae1fb4

Browse files
shmuelkclubanderson
authored and committed
Invoke the PostResponse handlers and send any added headers to the user
1 parent 6974c4f commit cae1fb4

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

pkg/epp/handlers/server.go

+44-15
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3838
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
3939
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
40+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4041
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4142
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
4243
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -66,6 +67,7 @@ type StreamingServer struct {
6667

6768
type Scheduler interface {
6869
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
70+
RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error)
6971
}
7072

7173
// RequestContext stores context information during the life time of an HTTP request.
@@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
189191
case *extProcPb.ProcessingRequest_RequestTrailers:
190192
// This is currently unused.
191193
case *extProcPb.ProcessingRequest_ResponseHeaders:
194+
responseHeaders := make(map[string]string)
192195
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
193196
value := string(header.RawValue)
194197

@@ -199,27 +202,53 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
199202
reqCtx.modelServerStreaming = true
200203
loggerTrace.Info("model server is streaming response")
201204
}
205+
responseHeaders[header.Key] = value
202206
}
203207

204-
reqCtx.RequestState = ResponseRecieved
205-
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
206-
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
207-
ResponseHeaders: &extProcPb.HeadersResponse{
208-
Response: &extProcPb.CommonResponse{
209-
HeaderMutation: &extProcPb.HeaderMutation{
210-
SetHeaders: []*configPb.HeaderValueOption{
211-
{
212-
Header: &configPb.HeaderValue{
213-
// This is for debugging purpose only.
214-
Key: "x-went-into-resp-headers",
215-
RawValue: []byte("true"),
216-
},
217-
},
208+
llmReq := &schedulingtypes.LLMRequest{
209+
Model: reqCtx.Model,
210+
Headers: responseHeaders,
211+
ResolvedTargetModel: reqCtx.ResolvedTargetModel,
212+
}
213+
214+
var result *types.Result
215+
result, err = s.scheduler.RunPostResponsePlugins(ctx, llmReq, reqCtx.TargetPod)
216+
if err != nil {
217+
logger.V(logutil.DEFAULT).Error(err, "Error handling response")
218+
reqCtx.ResponseStatusCode = errutil.ModelServerError
219+
} else {
220+
headers := []*configPb.HeaderValueOption{
221+
{
222+
Header: &configPb.HeaderValue{
223+
// This is for debugging purpose only.
224+
Key: "x-went-into-resp-headers",
225+
RawValue: []byte("true"),
226+
},
227+
},
228+
}
229+
230+
// Add headers added by PostResponse
231+
for key, value := range result.MutatedHeaders {
232+
headers = append(headers, &configPb.HeaderValueOption{
233+
Header: &configPb.HeaderValue{
234+
Key: key,
235+
RawValue: []byte(value),
236+
},
237+
})
238+
}
239+
240+
reqCtx.RequestState = ResponseRecieved
241+
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
242+
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
243+
ResponseHeaders: &extProcPb.HeadersResponse{
244+
Response: &extProcPb.CommonResponse{
245+
HeaderMutation: &extProcPb.HeaderMutation{
246+
SetHeaders: headers,
218247
},
219248
},
220249
},
221250
},
222-
},
251+
}
223252
}
224253

225254
case *extProcPb.ProcessingRequest_ResponseBody:

0 commit comments

Comments
 (0)