Skip to content

Commit 855436e

Browse files
authored
Weighted scorers (#737)
* removed unused noop plugin Signed-off-by: Nir Rozenbaum <[email protected]> * more scheduler refactoring Signed-off-by: Nir Rozenbaum <[email protected]> * more refactoring Signed-off-by: Nir Rozenbaum <[email protected]> * added weights to scorers and calculating weighted score Signed-off-by: Nir Rozenbaum <[email protected]> * addressed code review comments Signed-off-by: Nir Rozenbaum <[email protected]> --------- Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent e845173 commit 855436e

File tree

9 files changed

+190
-189
lines changed

9 files changed

+190
-189
lines changed

pkg/epp/scheduling/config.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,22 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
2020

2121
type SchedulerConfig struct {
2222
preSchedulePlugins []plugins.PreSchedule
23-
scorers []plugins.Scorer
2423
filters []plugins.Filter
25-
postSchedulePlugins []plugins.PostSchedule
24+
scorers map[plugins.Scorer]int // map from scorer to weight
2625
picker plugins.Picker
26+
postSchedulePlugins []plugins.PostSchedule
27+
}
28+
29+
var defPlugin = &defaultPlugin{}
30+
31+
// When the scheduler is initialized with the NewScheduler function, this config is used as the default.
32+
// It's possible to call NewSchedulerWithConfig to pass a different config instead.
33+
34+
// For build-time plugin changes, it's recommended to change the defaultConfig variable in this file.
35+
var defaultConfig = &SchedulerConfig{
36+
preSchedulePlugins: []plugins.PreSchedule{},
37+
filters: []plugins.Filter{defPlugin},
38+
scorers: map[plugins.Scorer]int{},
39+
picker: defPlugin,
40+
postSchedulePlugins: []plugins.PostSchedule{},
2741
}

pkg/epp/scheduling/default_config.go

Lines changed: 0 additions & 31 deletions
This file was deleted.

pkg/epp/scheduling/plugins/filter/filter_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ func TestFilter(t *testing.T) {
5454
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
5555
got := test.filter.Filter(ctx, test.input)
5656

57-
opt := cmp.AllowUnexported(types.PodMetrics{})
58-
if diff := cmp.Diff(test.output, got, opt); diff != "" {
57+
if diff := cmp.Diff(test.output, got); diff != "" {
5958
t.Errorf("Unexpected output (-want +got): %v", diff)
6059
}
6160
})
@@ -190,8 +189,7 @@ func TestFilterFunc(t *testing.T) {
190189
ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
191190
got := test.f(ctx, test.input)
192191

193-
opt := cmp.AllowUnexported(types.PodMetrics{})
194-
if diff := cmp.Diff(test.output, got, opt); diff != "" {
192+
if diff := cmp.Diff(test.output, got); diff != "" {
195193
t.Errorf("Unexpected output (-want +got): %v", diff)
196194
}
197195
})

pkg/epp/scheduling/plugins/noop.go

Lines changed: 0 additions & 42 deletions
This file was deleted.

pkg/epp/scheduling/plugins/picker/random_picker.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ import (
2020
"fmt"
2121
"math/rand"
2222

23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
2324
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2425
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2526
)
2627

28+
var _ plugins.Picker = &RandomPicker{}
29+
30+
// RandomPicker picks a random pod from the list of candidates.
2731
type RandomPicker struct{}
2832

2933
func (rp *RandomPicker) Name() string {
3034
return "random"
3135
}
3236

33-
func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result {
34-
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
35-
i := rand.Intn(len(pods))
36-
return &types.Result{TargetPod: pods[i]}
37+
func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
38+
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods))
39+
i := rand.Intn(len(scoredPods))
40+
return &types.Result{TargetPod: scoredPods[i].Pod}
3741
}

pkg/epp/scheduling/plugins/plugins.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,23 @@ type Filter interface {
4949
Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
5050
}
5151

52-
// Scorer defines the interface for scoring pods based on context.
52+
// Scorer defines the interface for scoring a list of pods based on context.
53+
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
5354
type Scorer interface {
5455
Plugin
55-
Score(ctx *types.SchedulingContext, pod types.Pod) float64
56+
Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64
5657
}
5758

58-
// PostSchedule is called by the scheduler after it selects a targetPod for the request.
59-
type PostSchedule interface {
59+
// Picker picks the final pod(s) to send the request to.
60+
type Picker interface {
6061
Plugin
61-
PostSchedule(ctx *types.SchedulingContext, res *types.Result)
62+
Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result
6263
}
6364

64-
// Picker picks the final pod(s) to send the request to.
65-
type Picker interface {
65+
// PostSchedule is called by the scheduler after it selects a targetPod for the request.
66+
type PostSchedule interface {
6667
Plugin
67-
Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result
68+
PostSchedule(ctx *types.SchedulingContext, res *types.Result)
6869
}
6970

7071
// PostResponse is called by the scheduler after a successful response was sent.

pkg/epp/scheduling/scheduler.go

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler {
7272
}
7373

7474
func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler {
75-
scheduler := &Scheduler{
75+
return &Scheduler{
7676
datastore: datastore,
7777
preSchedulePlugins: config.preSchedulePlugins,
78-
scorers: config.scorers,
7978
filters: config.filters,
80-
postSchedulePlugins: config.postSchedulePlugins,
79+
scorers: config.scorers,
8180
picker: config.picker,
81+
postSchedulePlugins: config.postSchedulePlugins,
8282
}
83-
84-
return scheduler
8583
}
8684

8785
type Scheduler struct {
8886
datastore Datastore
8987
preSchedulePlugins []plugins.PreSchedule
9088
filters []plugins.Filter
91-
scorers []plugins.Scorer
92-
postSchedulePlugins []plugins.PostSchedule
89+
scorers map[plugins.Scorer]int // map from scorer to its weight
9390
picker plugins.Picker
91+
postSchedulePlugins []plugins.PostSchedule
9492
}
9593

9694
type Datastore interface {
@@ -106,25 +104,22 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
106104
// 1. Reduce concurrent access to the datastore.
107105
// 2. Ensure consistent data during the scheduling operation of a request.
108106
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
109-
loggerDebug.Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot))
107+
loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))
110108

111109
s.runPreSchedulePlugins(sCtx)
112110

113111
pods := s.runFilterPlugins(sCtx)
114112
if len(pods) == 0 {
115113
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "failed to find a target pod"}
116114
}
115+
// if we got here, there is at least one pod to score
116+
weightedScorePerPod := s.runScorerPlugins(sCtx, pods)
117117

118-
s.runScorerPlugins(sCtx, pods)
119-
120-
before := time.Now()
121-
res := s.picker.Pick(sCtx, pods)
122-
metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
123-
loggerDebug.Info("After running picker plugins", "result", res)
118+
result := s.runPickerPlugin(sCtx, weightedScorePerPod)
124119

125-
s.runPostSchedulePlugins(sCtx, res)
120+
s.runPostSchedulePlugins(sCtx, result)
126121

127-
return res, nil
122+
return result, nil
128123
}
129124

130125
func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
@@ -136,15 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
136131
}
137132
}
138133

139-
func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
140-
for _, plugin := range s.postSchedulePlugins {
141-
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
142-
before := time.Now()
143-
plugin.PostSchedule(ctx, res)
144-
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
145-
}
146-
}
147-
148134
func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
149135
loggerDebug := ctx.Logger.V(logutil.DEBUG)
150136
filteredPods := ctx.PodsSnapshot
@@ -160,32 +146,60 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
160146
break
161147
}
162148
}
149+
loggerDebug.Info("After running filter plugins")
150+
163151
return filteredPods
164152
}
165153

166-
func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) {
154+
func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
167155
loggerDebug := ctx.Logger.V(logutil.DEBUG)
168-
loggerDebug.Info("Before running score plugins", "pods", pods)
156+
loggerDebug.Info("Before running scorer plugins", "pods", pods)
157+
158+
weightedScorePerPod := make(map[types.Pod]float64, len(pods))
169159
for _, pod := range pods {
170-
score := s.runScorersForPod(ctx, pod)
171-
pod.SetScore(score)
160+
weightedScorePerPod[pod] = float64(0) // initialize weighted score per pod with 0 value
161+
}
162+
// Iterate through each scorer in the chain and accumulate the weighted scores.
163+
for scorer, weight := range s.scorers {
164+
loggerDebug.Info("Running scorer", "scorer", scorer.Name())
165+
before := time.Now()
166+
scores := scorer.Score(ctx, pods)
167+
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
168+
for pod, score := range scores { // weight is relative to the sum of weights
169+
weightedScorePerPod[pod] += score * float64(weight) // TODO normalize score before multiply with weight
170+
}
171+
loggerDebug.Info("After running scorer", "scorer", scorer.Name())
172+
}
173+
loggerDebug.Info("After running scorer plugins")
174+
175+
return weightedScorePerPod
176+
}
177+
178+
func (s *Scheduler) runPickerPlugin(ctx *types.SchedulingContext, weightedScorePerPod map[types.Pod]float64) *types.Result {
179+
loggerDebug := ctx.Logger.V(logutil.DEBUG)
180+
scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod))
181+
i := 0
182+
for pod, score := range weightedScorePerPod {
183+
scoredPods[i] = &types.ScoredPod{Pod: pod, Score: score}
184+
i++
172185
}
173-
loggerDebug.Info("After running score plugins", "pods", pods)
186+
187+
loggerDebug.Info("Before running picker plugin", "pods", weightedScorePerPod)
188+
before := time.Now()
189+
result := s.picker.Pick(ctx, scoredPods)
190+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
191+
loggerDebug.Info("After running picker plugin", "result", result)
192+
193+
return result
174194
}
175195

176-
// Iterate through each scorer in the chain and accumulate the scores.
177-
func (s *Scheduler) runScorersForPod(ctx *types.SchedulingContext, pod types.Pod) float64 {
178-
logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG)
179-
score := float64(0)
180-
for _, scorer := range s.scorers {
181-
logger.Info("Running scorer", "scorer", scorer.Name())
196+
func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
197+
for _, plugin := range s.postSchedulePlugins {
198+
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
182199
before := time.Now()
183-
oneScore := scorer.Score(ctx, pod)
184-
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
185-
score += oneScore
186-
logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score)
200+
plugin.PostSchedule(ctx, res)
201+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
187202
}
188-
return score
189203
}
190204

191205
type defaultPlugin struct {

0 commit comments

Comments
 (0)