@@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler {
72
72
}
73
73
74
74
func NewSchedulerWithConfig (datastore Datastore , config * SchedulerConfig ) * Scheduler {
75
- scheduler := & Scheduler {
75
+ return & Scheduler {
76
76
datastore : datastore ,
77
77
preSchedulePlugins : config .preSchedulePlugins ,
78
- scorers : config .scorers ,
79
78
filters : config .filters ,
80
- postSchedulePlugins : config .postSchedulePlugins ,
79
+ scorers : config .scorers ,
81
80
picker : config .picker ,
81
+ postSchedulePlugins : config .postSchedulePlugins ,
82
82
}
83
-
84
- return scheduler
85
83
}
86
84
87
85
type Scheduler struct {
88
86
datastore Datastore
89
87
preSchedulePlugins []plugins.PreSchedule
90
88
filters []plugins.Filter
91
- scorers []plugins.Scorer
92
- postSchedulePlugins []plugins.PostSchedule
89
+ scorers map [plugins.Scorer ]int // map from scorer to its weight
93
90
picker plugins.Picker
91
+ postSchedulePlugins []plugins.PostSchedule
94
92
}
95
93
96
94
type Datastore interface {
@@ -106,25 +104,22 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
106
104
// 1. Reduce concurrent access to the datastore.
107
105
// 2. Ensure consistent data during the scheduling operation of a request.
108
106
sCtx := types .NewSchedulingContext (ctx , req , types .ToSchedulerPodMetrics (s .datastore .PodGetAll ()))
109
- loggerDebug .Info (fmt .Sprintf ("Scheduling a request. Metrics: %+v" , sCtx .PodsSnapshot ))
107
+ loggerDebug .Info (fmt .Sprintf ("Scheduling a request, Metrics: %+v" , sCtx .PodsSnapshot ))
110
108
111
109
s .runPreSchedulePlugins (sCtx )
112
110
113
111
pods := s .runFilterPlugins (sCtx )
114
112
if len (pods ) == 0 {
115
113
return nil , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : "failed to find a target pod" }
116
114
}
115
+ // if we got here, there is at least one pod to score
116
+ weightedScorePerPod := s .runScorerPlugins (sCtx , pods )
117
117
118
- s .runScorerPlugins (sCtx , pods )
119
-
120
- before := time .Now ()
121
- res := s .picker .Pick (sCtx , pods )
122
- metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
123
- loggerDebug .Info ("After running picker plugins" , "result" , res )
118
+ result := s .runPickerPlugin (sCtx , weightedScorePerPod )
124
119
125
- s .runPostSchedulePlugins (sCtx , res )
120
+ s .runPostSchedulePlugins (sCtx , result )
126
121
127
- return res , nil
122
+ return result , nil
128
123
}
129
124
130
125
func (s * Scheduler ) runPreSchedulePlugins (ctx * types.SchedulingContext ) {
@@ -136,15 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
136
131
}
137
132
}
138
133
139
- func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
140
- for _ , plugin := range s .postSchedulePlugins {
141
- ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
142
- before := time .Now ()
143
- plugin .PostSchedule (ctx , res )
144
- metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
145
- }
146
- }
147
-
148
134
func (s * Scheduler ) runFilterPlugins (ctx * types.SchedulingContext ) []types.Pod {
149
135
loggerDebug := ctx .Logger .V (logutil .DEBUG )
150
136
filteredPods := ctx .PodsSnapshot
@@ -160,32 +146,60 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
160
146
break
161
147
}
162
148
}
149
+ loggerDebug .Info ("After running filter plugins" )
150
+
163
151
return filteredPods
164
152
}
165
153
166
- func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) {
154
+ func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) map [types. Pod ] float64 {
167
155
loggerDebug := ctx .Logger .V (logutil .DEBUG )
168
- loggerDebug .Info ("Before running score plugins" , "pods" , pods )
156
+ loggerDebug .Info ("Before running scorer plugins" , "pods" , pods )
157
+
158
+ weightedScorePerPod := make (map [types.Pod ]float64 , len (pods ))
169
159
for _ , pod := range pods {
170
- score := s .runScorersForPod (ctx , pod )
171
- pod .SetScore (score )
160
+ weightedScorePerPod [pod ] = float64 (0 ) // initialize weighted score per pod with 0 value
161
+ }
162
+ // Iterate through each scorer in the chain and accumulate the weighted scores.
163
+ for scorer , weight := range s .scorers {
164
+ loggerDebug .Info ("Running scorer" , "scorer" , scorer .Name ())
165
+ before := time .Now ()
166
+ scores := scorer .Score (ctx , pods )
167
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
168
+ for pod , score := range scores { // weight is relative to the sum of weights
169
+ weightedScorePerPod [pod ] += score * float64 (weight ) // TODO normalize score before multiply with weight
170
+ }
171
+ loggerDebug .Info ("After running scorer" , "scorer" , scorer .Name ())
172
+ }
173
+ loggerDebug .Info ("After running scorer plugins" )
174
+
175
+ return weightedScorePerPod
176
+ }
177
+
178
+ func (s * Scheduler ) runPickerPlugin (ctx * types.SchedulingContext , weightedScorePerPod map [types.Pod ]float64 ) * types.Result {
179
+ loggerDebug := ctx .Logger .V (logutil .DEBUG )
180
+ scoredPods := make ([]* types.ScoredPod , len (weightedScorePerPod ))
181
+ i := 0
182
+ for pod , score := range weightedScorePerPod {
183
+ scoredPods [i ] = & types.ScoredPod {Pod : pod , Score : score }
184
+ i ++
172
185
}
173
- loggerDebug .Info ("After running score plugins" , "pods" , pods )
186
+
187
+ loggerDebug .Info ("Before running picker plugin" , "pods" , weightedScorePerPod )
188
+ before := time .Now ()
189
+ result := s .picker .Pick (ctx , scoredPods )
190
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
191
+ loggerDebug .Info ("After running picker plugin" , "result" , result )
192
+
193
+ return result
174
194
}
175
195
176
- // Iterate through each scorer in the chain and accumulate the scores.
177
- func (s * Scheduler ) runScorersForPod (ctx * types.SchedulingContext , pod types.Pod ) float64 {
178
- logger := ctx .Logger .WithValues ("pod" , pod .GetPod ().NamespacedName ).V (logutil .DEBUG )
179
- score := float64 (0 )
180
- for _ , scorer := range s .scorers {
181
- logger .Info ("Running scorer" , "scorer" , scorer .Name ())
196
+ func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
197
+ for _ , plugin := range s .postSchedulePlugins {
198
+ ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
182
199
before := time .Now ()
183
- oneScore := scorer .Score (ctx , pod )
184
- metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
185
- score += oneScore
186
- logger .Info ("After scorer" , "scorer" , scorer .Name (), "score" , oneScore , "total score" , score )
200
+ plugin .PostSchedule (ctx , res )
201
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
187
202
}
188
- return score
189
203
}
190
204
191
205
type defaultPlugin struct {
0 commit comments