@@ -37,6 +37,7 @@ import (
37
37
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
38
38
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
39
39
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
40
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
40
41
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
41
42
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
42
43
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -66,6 +67,7 @@ type StreamingServer struct {
66
67
67
68
type Scheduler interface {
68
69
Schedule (ctx context.Context , b * schedulingtypes.LLMRequest ) (result * schedulingtypes.Result , err error )
70
+ RunPostResponsePlugins (ctx context.Context , req * types.LLMRequest , tragetPodName string ) (* schedulingtypes.Result , error )
69
71
}
70
72
71
73
// RequestContext stores context information during the life time of an HTTP request.
@@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
189
191
case * extProcPb.ProcessingRequest_RequestTrailers :
190
192
// This is currently unused.
191
193
case * extProcPb.ProcessingRequest_ResponseHeaders :
194
+ responseHeaders := make (map [string ]string )
192
195
for _ , header := range v .ResponseHeaders .Headers .GetHeaders () {
193
196
value := string (header .RawValue )
194
197
@@ -199,27 +202,53 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
199
202
reqCtx .modelServerStreaming = true
200
203
loggerTrace .Info ("model server is streaming response" )
201
204
}
205
+ responseHeaders [header .Key ] = value
202
206
}
203
207
204
- reqCtx .RequestState = ResponseRecieved
205
- reqCtx .respHeaderResp = & extProcPb.ProcessingResponse {
206
- Response : & extProcPb.ProcessingResponse_ResponseHeaders {
207
- ResponseHeaders : & extProcPb.HeadersResponse {
208
- Response : & extProcPb.CommonResponse {
209
- HeaderMutation : & extProcPb.HeaderMutation {
210
- SetHeaders : []* configPb.HeaderValueOption {
211
- {
212
- Header : & configPb.HeaderValue {
213
- // This is for debugging purpose only.
214
- Key : "x-went-into-resp-headers" ,
215
- RawValue : []byte ("true" ),
216
- },
217
- },
208
+ llmReq := & schedulingtypes.LLMRequest {
209
+ Model : reqCtx .Model ,
210
+ Headers : responseHeaders ,
211
+ ResolvedTargetModel : reqCtx .ResolvedTargetModel ,
212
+ }
213
+
214
+ var result * types.Result
215
+ result , err = s .scheduler .RunPostResponsePlugins (ctx , llmReq , reqCtx .TargetPod )
216
+ if err != nil {
217
+ logger .V (logutil .DEFAULT ).Error (err , "Error handling response" )
218
+ reqCtx .ResponseStatusCode = errutil .ModelServerError
219
+ } else {
220
+ headers := []* configPb.HeaderValueOption {
221
+ {
222
+ Header : & configPb.HeaderValue {
223
+ // This is for debugging purpose only.
224
+ Key : "x-went-into-resp-headers" ,
225
+ RawValue : []byte ("true" ),
226
+ },
227
+ },
228
+ }
229
+
230
+ // Add headers added by PostResponse
231
+ for key , value := range result .MutatedHeaders {
232
+ headers = append (headers , & configPb.HeaderValueOption {
233
+ Header : & configPb.HeaderValue {
234
+ Key : key ,
235
+ RawValue : []byte (value ),
236
+ },
237
+ })
238
+ }
239
+
240
+ reqCtx .RequestState = ResponseRecieved
241
+ reqCtx .respHeaderResp = & extProcPb.ProcessingResponse {
242
+ Response : & extProcPb.ProcessingResponse_ResponseHeaders {
243
+ ResponseHeaders : & extProcPb.HeadersResponse {
244
+ Response : & extProcPb.CommonResponse {
245
+ HeaderMutation : & extProcPb.HeaderMutation {
246
+ SetHeaders : headers ,
218
247
},
219
248
},
220
249
},
221
250
},
222
- },
251
+ }
223
252
}
224
253
225
254
case * extProcPb.ProcessingRequest_ResponseBody :
0 commit comments