Skip to content

Commit d9cb903

Browse files
nirrozenbaum authored and rlakhtakia committed
Weighted scorers (kubernetes-sigs#737)
* removed unused noop plugin
  Signed-off-by: Nir Rozenbaum <[email protected]>
* more scheduler refactoring
  Signed-off-by: Nir Rozenbaum <[email protected]>
* more refactoring
  Signed-off-by: Nir Rozenbaum <[email protected]>
* added weights to scorers and calculating weighted score
  Signed-off-by: Nir Rozenbaum <[email protected]>
* addressed code review comments
  Signed-off-by: Nir Rozenbaum <[email protected]>
---------
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent c6ff494 commit d9cb903

File tree

8 files changed

+191
-219
lines changed

8 files changed

+191
-219
lines changed

pkg/epp/scheduling/config.go

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,22 @@ import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
2020

2121
type SchedulerConfig struct {
2222
preSchedulePlugins []plugins.PreSchedule
23-
scorers []plugins.Scorer
2423
filters []plugins.Filter
25-
postSchedulePlugins []plugins.PostSchedule
24+
scorers map[plugins.Scorer]int // map from scorer to weight
2625
picker plugins.Picker
26+
postSchedulePlugins []plugins.PostSchedule
27+
}
28+
29+
var defPlugin = &defaultPlugin{}
30+
31+
// When the scheduler is initialized with the NewScheduler function, this config is used as the default.
32+
// It is possible to call NewSchedulerWithConfig to pass a different config instead.
33+
34+
// For build-time plugin changes, it is recommended to change the defaultConfig variable in this file.
35+
var defaultConfig = &SchedulerConfig{
36+
preSchedulePlugins: []plugins.PreSchedule{},
37+
filters: []plugins.Filter{defPlugin},
38+
scorers: map[plugins.Scorer]int{},
39+
picker: defPlugin,
40+
postSchedulePlugins: []plugins.PostSchedule{},
2741
}

pkg/epp/scheduling/default_config.go

Lines changed: 0 additions & 31 deletions
This file was deleted.

pkg/epp/scheduling/plugins/noop.go

Lines changed: 0 additions & 42 deletions
This file was deleted.

pkg/epp/scheduling/plugins/picker/random_picker.go

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,18 +20,22 @@ import (
2020
"fmt"
2121
"math/rand"
2222

23+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
2324
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2425
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2526
)
2627

28+
var _ plugins.Picker = &RandomPicker{}
29+
30+
// RandomPicker picks a random pod from the list of candidates.
2731
type RandomPicker struct{}
2832

2933
func (rp *RandomPicker) Name() string {
3034
return "random"
3135
}
3236

33-
func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result {
34-
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
35-
i := rand.Intn(len(pods))
36-
return &types.Result{TargetPod: pods[i]}
37+
func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
38+
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods))
39+
i := rand.Intn(len(scoredPods))
40+
return &types.Result{TargetPod: scoredPods[i].Pod}
3741
}

pkg/epp/scheduling/plugins/plugins.go

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -49,22 +49,23 @@ type Filter interface {
4949
Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
5050
}
5151

52-
// Scorer defines the interface for scoring pods based on context.
52+
// Scorer defines the interface for scoring a list of pods based on context.
53+
// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
5354
type Scorer interface {
5455
Plugin
55-
Score(ctx *types.SchedulingContext, pod types.Pod) float64
56+
Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64
5657
}
5758

58-
// PostSchedule is called by the scheduler after it selects a targetPod for the request.
59-
type PostSchedule interface {
59+
// Picker picks the final pod(s) to send the request to.
60+
type Picker interface {
6061
Plugin
61-
PostSchedule(ctx *types.SchedulingContext, res *types.Result)
62+
Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result
6263
}
6364

64-
// Picker picks the final pod(s) to send the request to.
65-
type Picker interface {
65+
// PostSchedule is called by the scheduler after it selects a targetPod for the request.
66+
type PostSchedule interface {
6667
Plugin
67-
Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result
68+
PostSchedule(ctx *types.SchedulingContext, res *types.Result)
6869
}
6970

7071
// PostResponse is called by the scheduler after a successful response was sent.

pkg/epp/scheduling/scheduler.go

Lines changed: 58 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler {
7272
}
7373

7474
func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler {
75-
scheduler := &Scheduler{
75+
return &Scheduler{
7676
datastore: datastore,
7777
preSchedulePlugins: config.preSchedulePlugins,
78-
scorers: config.scorers,
7978
filters: config.filters,
80-
postSchedulePlugins: config.postSchedulePlugins,
79+
scorers: config.scorers,
8180
picker: config.picker,
81+
postSchedulePlugins: config.postSchedulePlugins,
8282
}
83-
84-
return scheduler
8583
}
8684

8785
type Scheduler struct {
8886
datastore Datastore
8987
preSchedulePlugins []plugins.PreSchedule
9088
filters []plugins.Filter
91-
scorers []plugins.Scorer
92-
postSchedulePlugins []plugins.PostSchedule
89+
scorers map[plugins.Scorer]int // map from scorer to its weight
9390
picker plugins.Picker
91+
postSchedulePlugins []plugins.PostSchedule
9492
}
9593

9694
type Datastore interface {
@@ -106,25 +104,22 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
106104
// 1. Reduce concurrent access to the datastore.
107105
// 2. Ensure consistent data during the scheduling operation of a request.
108106
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
109-
loggerDebug.Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot))
107+
loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))
110108

111109
s.runPreSchedulePlugins(sCtx)
112110

113111
pods := s.runFilterPlugins(sCtx)
114112
if len(pods) == 0 {
115113
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "failed to find a target pod"}
116114
}
115+
// if we got here, there is at least one pod to score
116+
weightedScorePerPod := s.runScorerPlugins(sCtx, pods)
117117

118-
s.runScorerPlugins(sCtx, pods)
118+
result := s.runPickerPlugin(sCtx, weightedScorePerPod)
119119

120-
before := time.Now()
121-
res := s.picker.Pick(sCtx, pods)
122-
metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
123-
loggerDebug.Info("After running picker plugins", "result", res)
124-
125-
s.runPostSchedulePlugins(sCtx, res)
120+
s.runPostSchedulePlugins(sCtx, result)
126121

127-
return res, nil
122+
return result, nil
128123
}
129124

130125
func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
@@ -136,22 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
136131
}
137132
}
138133

139-
func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
140-
for _, plugin := range s.postSchedulePlugins {
141-
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
142-
before := time.Now()
143-
plugin.PostSchedule(ctx, res)
144-
<<<<<<< HEAD
145-
<<<<<<< HEAD
146-
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
147-
}
148-
}
149-
=======
150-
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
151-
}
152-
}
153-
154-
>>>>>>> b24f948 (scheduler refactoring (#730))
155134
func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
156135
loggerDebug := ctx.Logger.V(logutil.DEBUG)
157136
filteredPods := ctx.PodsSnapshot
@@ -167,70 +146,74 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
167146
break
168147
}
169148
}
149+
loggerDebug.Info("After running filter plugins")
150+
170151
return filteredPods
171152
}
172153

173-
func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) {
154+
func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
174155
loggerDebug := ctx.Logger.V(logutil.DEBUG)
175-
loggerDebug.Info("Before running score plugins", "pods", pods)
156+
loggerDebug.Info("Before running scorer plugins", "pods", pods)
157+
158+
weightedScorePerPod := make(map[types.Pod]float64, len(pods))
176159
for _, pod := range pods {
177-
score := s.runScorersForPod(ctx, pod)
178-
pod.SetScore(score)
160+
weightedScorePerPod[pod] = float64(0) // initialize weighted score per pod with 0 value
161+
}
162+
// Iterate through each scorer in the chain and accumulate the weighted scores.
163+
for scorer, weight := range s.scorers {
164+
loggerDebug.Info("Running scorer", "scorer", scorer.Name())
165+
before := time.Now()
166+
scores := scorer.Score(ctx, pods)
167+
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
168+
for pod, score := range scores { // weight is relative to the sum of weights
169+
weightedScorePerPod[pod] += score * float64(weight) // TODO normalize score before multiply with weight
170+
}
171+
loggerDebug.Info("After running scorer", "scorer", scorer.Name())
179172
}
180-
loggerDebug.Info("After running score plugins", "pods", pods)
173+
loggerDebug.Info("After running scorer plugins")
174+
175+
return weightedScorePerPod
181176
}
182177

183-
// Iterate through each scorer in the chain and accumulate the scores.
184-
func (s *Scheduler) runScorersForPod(ctx *types.SchedulingContext, pod types.Pod) float64 {
185-
logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG)
186-
score := float64(0)
187-
for _, scorer := range s.scorers {
188-
logger.Info("Running scorer", "scorer", scorer.Name())
178+
func (s *Scheduler) runPickerPlugin(ctx *types.SchedulingContext, weightedScorePerPod map[types.Pod]float64) *types.Result {
179+
loggerDebug := ctx.Logger.V(logutil.DEBUG)
180+
scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod))
181+
i := 0
182+
for pod, score := range weightedScorePerPod {
183+
scoredPods[i] = &types.ScoredPod{Pod: pod, Score: score}
184+
i++
185+
}
186+
187+
loggerDebug.Info("Before running picker plugin", "pods", weightedScorePerPod)
188+
before := time.Now()
189+
result := s.picker.Pick(ctx, scoredPods)
190+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
191+
loggerDebug.Info("After running picker plugin", "result", result)
192+
193+
return result
194+
}
195+
196+
func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
197+
for _, plugin := range s.postSchedulePlugins {
198+
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
189199
before := time.Now()
190-
oneScore := scorer.Score(ctx, pod)
191-
metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
192-
score += oneScore
193-
logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score)
200+
plugin.PostSchedule(ctx, res)
201+
metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
194202
}
195-
return score
196203
}
197204

198205
type defaultPlugin struct {
199206
picker.RandomPicker
200-
<<<<<<< HEAD
201-
=======
202-
metrics.RecordSchedulerPluginProcessingLatency(types.PostSchedulePluginType, plugin.Name(), time.Since(before))
203-
}
204-
}
205-
=======
206-
>>>>>>> b24f948 (scheduler refactoring (#730))
207207
}
208208

209-
<<<<<<< HEAD
210-
func (p *defaultPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
211-
if ctx.Req.Critical {
212-
return lowLatencyFilter.Filter(ctx, pods)
213-
}
209+
func (p *defaultPlugin) Name() string {
210+
return "DefaultPlugin"
211+
}
214212

215-
<<<<<<< HEAD
216-
return sheddableRequestFilter.Filter(ctx, pods)
217-
=======
218-
func (p *defaultPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
219-
req := ctx.Req
220-
var filter types.Filter
221-
if req.Critical {
222-
filter = lowLatencyFilter
223-
} else {
224-
filter = sheddableRequestFilter
225-
}
226-
return filter.Filter(ctx, pods)
227-
>>>>>>> 45209f6 (Refactor scheduler to run plugins (#677))
228-
=======
229213
func (p *defaultPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
230214
if ctx.Req.Critical {
231215
return lowLatencyFilter.Filter(ctx, pods)
232216
}
233217

234218
return sheddableRequestFilter.Filter(ctx, pods)
235-
>>>>>>> b24f948 (scheduler refactoring (#730))
236219
}

0 commit comments

Comments
 (0)