Weighted scorers #737

Merged
5 commits merged on Apr 28, 2025
16 changes: 15 additions & 1 deletion pkg/epp/scheduling/config.go
@@ -20,8 +20,22 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"

type SchedulerConfig struct {
preSchedulePlugins []plugins.PreSchedule
scorers []plugins.Scorer
filters []plugins.Filter
scorers map[plugins.Scorer]int // map from scorer to weight
postSchedulePlugins []plugins.PostSchedule
picker plugins.Picker
}

var defPlugin = &defaultPlugin{}

// When the scheduler is initialized with the NewScheduler function, this config is used as the default.
// It's possible to call NewSchedulerWithConfig to pass a different configuration.

// For build-time plugin changes, it's recommended to change the defaultConfig variable in this file.
@nirrozenbaum (Contributor, Author), Apr 24, 2025:

PAY ATTENTION.
When someone wants to change the GIE default filters/scorers/picker/etc. by forking the repo, the flow is the following (see the sketch after this comment):

  • Fork the repo.
  • Add their own plugins in new files.
  • Change ONLY THIS FILE to include their set of plugins, using the defaultConfig variable.
  • That's it. Build the EPP.

The idea behind this small file is to let such users keep their fork synced without drifting from upstream: since they have only this single-file change, they can sync the rest of the code to pick up bug fixes and the like.
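For illustration only, here is a hedged sketch of what such a fork-side edit to config.go might look like; MyCustomFilter, MyKVCacheScorer, and MyQueueScorer are hypothetical plugins (not part of this PR), and the weights are relative to their sum.

```go
// Sketch of a fork's defaultConfig (hypothetical plugins, same package):
// only this variable changes; the rest of the package stays in sync with upstream.
var defaultConfig = &SchedulerConfig{
	preSchedulePlugins: []plugins.PreSchedule{},
	filters:            []plugins.Filter{&MyCustomFilter{}},
	scorers: map[plugins.Scorer]int{
		&MyKVCacheScorer{}: 3, // 3 of 4 => 75% of the weighted score
		&MyQueueScorer{}:   1, // 1 of 4 => 25% of the weighted score
	},
	postSchedulePlugins: []plugins.PostSchedule{},
	picker:              defPlugin,
}
```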

Contributor:

Why not manage two separate files: one for the SchedulerConfig definition, which would not be changed in a personal fork (since it could change upstream), and another that defines the config instance with all the required filters/scorers/picker/...?

Contributor Author:

If SchedulerConfig has changed upstream (e.g., it includes new fields), that means the Scheduler itself changed upstream (since the scheduler uses the fields from the config). In such a case it's not possible to keep the previous config in a personal fork and use the scheduler from upstream.

var defaultConfig = &SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{},
filters: []plugins.Filter{defPlugin},
scorers: map[plugins.Scorer]int{},
postSchedulePlugins: []plugins.PostSchedule{},
picker: defPlugin,
}
31 changes: 0 additions & 31 deletions pkg/epp/scheduling/default_config.go

This file was deleted.

6 changes: 2 additions & 4 deletions pkg/epp/scheduling/plugins/filter/filter_test.go
@@ -54,8 +54,7 @@ func TestFilter(t *testing.T) {
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
got := test.filter.Filter(ctx, test.input)

opt := cmp.AllowUnexported(types.PodMetrics{})
if diff := cmp.Diff(test.output, got, opt); diff != "" {
if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
})
@@ -190,8 +189,7 @@ func TestFilterFunc(t *testing.T) {
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
got := test.f(ctx, test.input)

opt := cmp.AllowUnexported(types.PodMetrics{})
if diff := cmp.Diff(test.output, got, opt); diff != "" {
if diff := cmp.Diff(test.output, got); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
})
42 changes: 0 additions & 42 deletions pkg/epp/scheduling/plugins/noop.go

This file was deleted.

22 changes: 18 additions & 4 deletions pkg/epp/scheduling/plugins/picker/random_picker.go
@@ -20,18 +20,32 @@ import (
"fmt"
"math/rand"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

var _ plugins.Picker = &RandomPicker{}

// RandomPicker picks a random pod from the list of candidates.
type RandomPicker struct{}

func (rp *RandomPicker) Name() string {
return "random"
}

func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result {
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
i := rand.Intn(len(pods))
return &types.Result{TargetPod: pods[i]}
func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, scoredPods map[types.Pod]float64) *types.Result {
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates", len(scoredPods)))
selectedIndex := rand.Intn(len(scoredPods))
i := 0
var randomPod types.Pod
for pod := range scoredPods {
if selectedIndex == i {
randomPod = pod
break

}
i++
}
return &types.Result{TargetPod: randomPod}
}
7 changes: 4 additions & 3 deletions pkg/epp/scheduling/plugins/plugins.go
@@ -49,10 +49,11 @@ type Filter interface {
Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
}

// Scorer defines the interface for scoring pods based on context.
// Scorer defines the interface for scoring a list of pods based on context.
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
type Scorer interface {
Plugin
Score(ctx *types.SchedulingContext, pod types.Pod) float64
Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64
}

// PostSchedule is called by the scheduler after it selects a targetPod for the request.
Expand All @@ -64,7 +65,7 @@ type PostSchedule interface {
// Picker picks the final pod(s) to send the request to.
type Picker interface {
Plugin
Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result
Pick(ctx *types.SchedulingContext, scoredPods map[types.Pod]float64) *types.Result
}

// PostResponse is called by the scheduler after a successful response was sent.
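For orientation, a minimal scorer satisfying the new batched interface could look like the sketch below; uniformScorer is a hypothetical name and not part of this PR. It scores every candidate identically, purely to show the shape of the Score method.

```go
// uniformScorer is a hypothetical example of the batched Scorer interface:
// it returns a score in [0,1] for every candidate pod in a single call.
type uniformScorer struct{}

func (s *uniformScorer) Name() string { return "uniform" }

func (s *uniformScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
	scores := make(map[types.Pod]float64, len(pods))
	for _, pod := range pods {
		scores[pod] = 0.5 // same mid-range score for every pod
	}
	return scores
}
```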
69 changes: 37 additions & 32 deletions pkg/epp/scheduling/scheduler.go
@@ -72,13 +72,18 @@ func NewScheduler(datastore Datastore) *Scheduler {
}

func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler {
sumOfScorersWeights := 0
for _, weight := range config.scorers {
sumOfScorersWeights += weight
}
scheduler := &Scheduler{
datastore: datastore,
preSchedulePlugins: config.preSchedulePlugins,
scorers: config.scorers,
filters: config.filters,
scorers: config.scorers,
postSchedulePlugins: config.postSchedulePlugins,
picker: config.picker,
sumOfScorersWeights: sumOfScorersWeights,
}

return scheduler
@@ -88,9 +93,10 @@ type Scheduler struct {
datastore Datastore
preSchedulePlugins []plugins.PreSchedule
filters []plugins.Filter
scorers []plugins.Scorer
scorers map[plugins.Scorer]int // map from scorer to its weight
postSchedulePlugins []plugins.PostSchedule
picker plugins.Picker
sumOfScorersWeights int
}

type Datastore interface {
@@ -106,21 +112,21 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request.
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
loggerDebug.Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot))
loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))

s.runPreSchedulePlugins(sCtx)

pods := s.runFilterPlugins(sCtx)
if len(pods) == 0 {
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "failed to find a target pod"}
Collaborator:

I don't think it's in scope of this PR, but we should think about changing this error code. There's no guarantee that resources are exhausted when there are no pods, just that the filter for the specific request came up with nothing. I worry that sends incorrect signals to the reader.

@liu-cong (Contributor), Apr 25, 2025:

This is a regression from b24f948, I believe. Previously this error was returned by the sheddableRequestFilter, which knows this is a resource-exhausted error.

I would argue for keeping the errors from the Filters and Scorers, even if we don't need them today. This is more future-proof and addresses issues like this one. @nirrozenbaum, what do you think about bringing the errors back? Not asking to do it in this PR; it can be a follow-up.

@nirrozenbaum (Contributor, Author), Apr 27, 2025:

As of today, we have the following filters in main: LowQueueFilter, LoRAAffinityFilter, LeastQueueFilter, LeastKVCacheFilter, HasCapacityFilter. In addition to those, we're currently implementing a few more filters and scorers at IBM that will be added to GIE when ready. None of the above uses an error. I think this should give us some hints about the usage of errors in filters/scorers/picker.

I'm in favor of being flexible, but I think we should be use-case driven and not try to be future-proof; otherwise we may end up implementing things that are not used and/or not relevant.

For example, we changed the Score function to receive all pods. This is flexible because it allows batching pods to an external scorer while still allowing each pod to be scored separately. But this change was use-case driven: it is required since we're developing a KVCache scorer at IBM that uses Redis as an external service, and @kfswain also showed in his py-go that it's useful. (A sketch of the batching pattern follows below.)
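A hedged sketch of the batching pattern described above; externalScorer and fetchExternalScores are invented for illustration and do not come from this PR or from the IBM scorer.

```go
// externalScorer is a hypothetical scorer that queries an external service once
// for the whole candidate list and maps the results back onto the pods.
type externalScorer struct {
	// fetchExternalScores is assumed to return a score in [0,1] per pod name.
	fetchExternalScores func(podNames []string) map[string]float64
}

func (s *externalScorer) Name() string { return "external" }

func (s *externalScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
	names := make([]string, 0, len(pods))
	for _, pod := range pods {
		names = append(names, pod.GetPod().NamespacedName.String())
	}
	remote := s.fetchExternalScores(names) // one round trip for the whole batch

	scores := make(map[types.Pod]float64, len(pods))
	for _, pod := range pods {
		scores[pod] = remote[pod.GetPod().NamespacedName.String()]
	}
	return scores
}
```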

Contributor:

DropRequest is a use case for returning an error. How do we solve that without an error?
My point about being future-proof is that it will take a lot of refactoring effort to add errors back later, once we have many more plugins.

Contributor Author:

We started this discussion on #677. I'll give my interpretation of the question here:
From my understanding, DropRequestFilter == FilterAllPods (please correct me if I'm wrong). In other words, it is a filter whose result is "no pods passed this filter". This is not unique to this filter and can happen in any filter (the scheduler is extensible now; new filters might be written and used).

First point: we can argue whether a filter whose result is an empty list is an error. I claim there was no error in the filter itself, as the filter worked correctly and filtered according to its definition. The result is that we cannot serve the request, but that is not an error in the filter (the scheduler returns an error if len(filtered_pods) == 0).

Second point, referring to the code as it was in #677 (since it has changed since then): this error was used in two places in the code:

To summarize, IMO the DropRequest that was previously used is not a good example of error usage, since it is not an error. It was used (I think incorrectly) to mark that no pods were left; for that check, we can simply check the length of the pods returned by the filter.

Contributor:

You have a lot of great points!

If a filter results in 0 pods, this effectively leads to an error for the scheduler. While it's debatable whether it's an error from the filter's perspective, we need to communicate this to the scheduler. Currently the scheduler simply returns a ResourceExhausted error, which seems too broad. IMO the filter should communicate a reason why it filtered out all pods.

How about adding another "filterReason" return value (perhaps an enum), if an error is too strong, so the scheduler can interpret the empty pod list and return a proper error to consumers?
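One possible shape of the proposed filterReason (not implemented anywhere; names invented for illustration):

```go
// FilterReason is a hypothetical enum a filter could return alongside the
// filtered pods, so the scheduler can surface a more specific error when the
// candidate list comes back empty.
type FilterReason int

const (
	FilterReasonNone       FilterReason = iota // nothing notable to report
	FilterReasonNoCapacity                     // e.g., every pod was saturated
	FilterReasonNoAffinity                     // e.g., no pod matched the requested LoRA adapter
)
```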

Contributor Author:

Yes, I agree that ResourceExhausted is not the right error to use here and that it should be changed. We're in an intermediate phase and I'm trying to break the scheduler refactoring into smaller PRs, to allow easier reviews and merges. I will surely change this in the next PR.
I tend to say this discussion is only about filters. Once there is at least one pod left after the filter phase, I really don't think we should abort a scheduling cycle because of a temporary issue with one of the scorers; that scorer can return a zero value for all candidate pods until the issue is resolved (and write the issue to the log). This is very different from the kube-scheduler framework in this respect.

"filterReason" is another word for error :) If there is no error, we won't use the reason.
So if we decide to add such a thing, it's better to use an error, which can be wrapped in upper-layer errors.

Can you give an example of a "filterReason" you have in mind?
I think including in the returned error which filter filtered out all the pods is useful; this is what I planned for my next PR (maybe I can push it as a separate PR today).
But what reason could there be other than "the candidate pods don't meet the filter criteria"?
In the end, it's a boolean condition. (A sketch of an error that names the filter follows below.)
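A hedged sketch of the kind of error described here, naming the filter that emptied the candidate list; this is illustrative only and not code from this PR:

```go
// Inside the filter loop (sketch): if a filter removes every remaining pod,
// report which filter did it instead of a generic ResourceExhausted error.
filteredPods = filter.Filter(sCtx, filteredPods)
if len(filteredPods) == 0 {
	return nil, fmt.Errorf("filter %q filtered out all candidate pods", filter.Name())
}
```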

}

s.runScorerPlugins(sCtx, pods)
// if we got here, there is at least one pod to score
weightedScorePerPod := s.runScorerPlugins(sCtx, pods)

before := time.Now()
res := s.picker.Pick(sCtx, pods)
res := s.picker.Pick(sCtx, weightedScorePerPod)
metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
loggerDebug.Info("After running picker plugins", "result", res)
loggerDebug.Info("After running picker plugin", "result", res)

s.runPostSchedulePlugins(sCtx, res)

@@ -136,15 +142,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
}
}

func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
for _, plugin := range s.postSchedulePlugins {
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
before := time.Now()
plugin.PostSchedule(ctx, res)
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
}
}

func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
loggerDebug := ctx.Logger.V(logutil.DEBUG)
filteredPods := ctx.PodsSnapshot
@@ -163,29 +160,37 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
return filteredPods
}

func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) {
func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
loggerDebug := ctx.Logger.V(logutil.DEBUG)
loggerDebug.Info("Before running score plugins", "pods", pods)
loggerDebug.Info("Before running scorer plugins", "pods", pods)

weightedScorePerPod := make(map[types.Pod]float64, len(pods))
for _, pod := range pods {
score := s.runScorersForPod(ctx, pod)
pod.SetScore(score)
weightedScorePerPod[pod] = float64(0) // initialize weighted score per pod with 0 value
Contributor:

nit: I don't think this is necessary.

Contributor Author:

If no scorers are configured (as in main today, for example), removing this line causes errors. Try removing it locally and running the unit tests to see the error.

}
// Iterate through each scorer in the chain and accumulate the weighted scores.
for scorer, weight := range s.scorers {
loggerDebug.Info("Running scorer", "scorer", scorer.Name())
before := time.Now()
scores := scorer.Score(ctx, pods)
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
for pod, score := range scores { // weight is relative to the sum of weights
weightedScorePerPod[pod] += score * float64(weight) / float64(s.sumOfScorersWeights) // TODO normalize score before multiply with weight
}
loggerDebug.Info("After running scorer", "scorer", scorer.Name())
}
loggerDebug.Info("After running score plugins", "pods", pods)
loggerDebug.Info("After running scorer plugins", "pods", pods)

return weightedScorePerPod
}
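To make the accumulation concrete, here is a small self-contained sketch (illustrative values only) of how two scorers with weights 2 and 1 combine into one weighted score per pod; because each scorer emits values in [0,1] and the weights are divided by their sum, the result stays in [0,1].

```go
package main

import "fmt"

func main() {
	// Hypothetical scores for a single pod from two scorers.
	scores := map[string]float64{"scorerA": 0.9, "scorerB": 0.3}
	weights := map[string]int{"scorerA": 2, "scorerB": 1}

	sumOfWeights := 0
	for _, w := range weights {
		sumOfWeights += w // 3, mirroring sumOfScorersWeights
	}

	weighted := 0.0
	for name, s := range scores {
		// Same formula as runScorerPlugins: score * weight / sumOfScorersWeights.
		weighted += s * float64(weights[name]) / float64(sumOfWeights)
	}
	fmt.Printf("%.2f\n", weighted) // 0.9*2/3 + 0.3*1/3 = 0.70
}
```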

// Iterate through each scorer in the chain and accumulate the scores.
func (s *Scheduler) runScorersForPod(ctx *types.SchedulingContext, pod types.Pod) float64 {
logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG)
score := float64(0)
for _, scorer := range s.scorers {
logger.Info("Running scorer", "scorer", scorer.Name())
func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
for _, plugin := range s.postSchedulePlugins {
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
before := time.Now()
oneScore := scorer.Score(ctx, pod)
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
score += oneScore
logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score)
plugin.PostSchedule(ctx, res)
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
}
return score
}

type defaultPlugin struct {