Skip to content

[UPSTREAM-SYNC] Feat: Add invocation of Post response plugins #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: upstream-sync
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions pkg/epp/handlers/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ type StreamingServer struct {

type Scheduler interface {
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
OnResponse(ctx context.Context, req *schedulingtypes.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error)
}

// RequestContext stores context information during the life time of an HTTP request.
Expand Down Expand Up @@ -203,6 +204,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
case *extProcPb.ProcessingRequest_RequestTrailers:
// This is currently unused.
case *extProcPb.ProcessingRequest_ResponseHeaders:
responseHeaders := make(map[string]string)
for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
value := string(header.RawValue)

Expand All @@ -213,26 +215,52 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
reqCtx.modelServerStreaming = true
loggerTrace.Info("model server is streaming response")
}
responseHeaders[header.Key] = value
}
reqCtx.RequestState = ResponseRecieved
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
// This is for debugging purpose only.
Key: "x-went-into-resp-headers",
RawValue: []byte("true"),
},
},

llmReq := &schedulingtypes.LLMRequest{
Model: reqCtx.Model,
Headers: responseHeaders,
ResolvedTargetModel: reqCtx.ResolvedTargetModel,
}

var result *schedulingtypes.Result
result, err = s.scheduler.OnResponse(ctx, llmReq, reqCtx.TargetPod)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error handling response")
reqCtx.ResponseStatusCode = errutil.ModelServerError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q: will returning an error not send the response back to the user?
If this is an internal logic failure and we've already compute the LLM response, might be wasteful.

} else {
headers := []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
// This is for debugging purpose only.
Key: "x-went-into-resp-headers",
RawValue: []byte("true"),
},
},
}

// Add headers added by PostResponse
for key, value := range result.MutatedHeaders {
headers = append(headers, &configPb.HeaderValueOption{
Header: &configPb.HeaderValue{
Key: key,
RawValue: []byte(value),
},
})
}
reqCtx.RequestState = ResponseRecieved
reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
},
},
},
},
}
}

case *extProcPb.ProcessingRequest_ResponseBody:
Expand Down
6 changes: 5 additions & 1 deletion pkg/epp/scheduling/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"

// NewSchedulerConfig creates a new SchedulerConfig object with the given plugins.
func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int,
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule) *SchedulerConfig {
picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule,
postResponsePlugins []plugins.PostResponse) *SchedulerConfig {
return &SchedulerConfig{
preSchedulePlugins: preSchedulePlugins,
filters: filters,
scorers: scorers,
picker: picker,
postSchedulePlugins: postSchedulePlugins,
postResponsePlugins: postResponsePlugins,
}
}

Expand All @@ -37,6 +39,7 @@ type SchedulerConfig struct {
scorers map[plugins.Scorer]int // map from scorer to weight
picker plugins.Picker
postSchedulePlugins []plugins.PostSchedule
postResponsePlugins []plugins.PostResponse
}

var defPlugin = &defaultPlugin{}
Expand All @@ -51,4 +54,5 @@ var defaultConfig = &SchedulerConfig{
scorers: map[plugins.Scorer]int{},
picker: defPlugin,
postSchedulePlugins: []plugins.PostSchedule{},
postResponsePlugins: []plugins.PostResponse{},
}
33 changes: 33 additions & 0 deletions pkg/epp/scheduling/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched
scorers: config.scorers,
picker: config.picker,
postSchedulePlugins: config.postSchedulePlugins,
postResponsePlugins: config.postResponsePlugins,
}
}

Expand All @@ -89,6 +90,7 @@ type Scheduler struct {
scorers map[plugins.Scorer]int // map from scorer to its weight
picker plugins.Picker
postSchedulePlugins []plugins.PostSchedule
postResponsePlugins []plugins.PostResponse
}

type Datastore interface {
Expand Down Expand Up @@ -208,6 +210,37 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty
}
}

// OnResponse is invoked during the processing of a response from an inference pod. It will invoke
// any defined plugins that process the response.
func (s *Scheduler) OnResponse(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) {
// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request.
pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll())
var targetPod types.Pod
for _, pod := range pods {
if pod.GetPod().NamespacedName.String() == targetPodName {
targetPod = pod
break
}
}
Comment on lines +221 to +226
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a pod was selected scheduler and crashed before this call, what happens here?
targetPod remains nil?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct.
The whole question of failure handling is not sufficiently specified in the system and needs to be addressed. There are many failure scenarios (e.g., EPP crashes during different points in processing, envoy crashes, ...) across and inside components.


sCtx := types.NewSchedulingContext(ctx, req, pods)

s.runPostResponsePlugins(sCtx, targetPod)

return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil
}

func (s *Scheduler) runPostResponsePlugins(ctx *types.SchedulingContext, targetPod types.Pod) {
for _, plugin := range s.postResponsePlugins {
ctx.Logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name())
before := time.Now()
plugin.PostResponse(ctx, targetPod)
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before))
}
}

type defaultPlugin struct {
picker.RandomPicker
}
Expand Down
70 changes: 70 additions & 0 deletions pkg/epp/scheduling/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,59 @@ func TestSchedulePlugins(t *testing.T) {
}
}

func TestPostResponse(t *testing.T) {
pr1 := &testPostResponse{
NameRes: "pr1",
ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"},
ReceivedResponseHeaders: make(map[string]string),
}

targetPod := k8stypes.NamespacedName{Name: "pod2"}

tests := []struct {
name string
config SchedulerConfig
input []*backendmetrics.FakePodMetrics
responseHeaders map[string]string
wantMutatedHeaders map[string]string
}{
{
name: "Simple postResponse test",
config: SchedulerConfig{
postResponsePlugins: []plugins.PostResponse{pr1},
},
input: []*backendmetrics.FakePodMetrics{
{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
{Pod: &backend.Pod{NamespacedName: targetPod}},
},
responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"},
wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"},
},
}

for _, test := range tests {
scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)

req := &types.LLMRequest{
Model: "test-model",
Headers: test.responseHeaders,
}

result, err := scheduler.OnResponse(context.Background(), req, targetPod.String())
if err != nil {
t.Errorf("Received an error. Error: %s", err)
}

if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" {
t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff)
}

if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" {
t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff)
}
}
}

type fakeDataStore struct {
pods []*backendmetrics.FakePodMetrics
}
Expand Down Expand Up @@ -548,6 +601,23 @@ func (tp *TestPlugin) reset() {
tp.NumOfPickerCandidates = 0
}

type testPostResponse struct {
NameRes string
ReceivedResponseHeaders map[string]string
ExtraHeaders map[string]string
}

func (pr *testPostResponse) Name() string { return pr.NameRes }

func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {
for key, value := range ctx.Req.Headers {
pr.ReceivedResponseHeaders[key] = value
}
for key, value := range pr.ExtraHeaders {
ctx.MutatedHeaders[key] = value
}
}

func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod {
res := []types.Pod{}
for _, pod := range ctx.PodsSnapshot {
Expand Down