@@ -72,25 +72,23 @@ func NewScheduler(datastore Datastore) *Scheduler {
72
72
}
73
73
74
74
func NewSchedulerWithConfig (datastore Datastore , config * SchedulerConfig ) * Scheduler {
75
- scheduler := & Scheduler {
75
+ return & Scheduler {
76
76
datastore : datastore ,
77
77
preSchedulePlugins : config .preSchedulePlugins ,
78
- scorers : config .scorers ,
79
78
filters : config .filters ,
80
- postSchedulePlugins : config .postSchedulePlugins ,
79
+ scorers : config .scorers ,
81
80
picker : config .picker ,
81
+ postSchedulePlugins : config .postSchedulePlugins ,
82
82
}
83
-
84
- return scheduler
85
83
}
86
84
87
85
type Scheduler struct {
88
86
datastore Datastore
89
87
preSchedulePlugins []plugins.PreSchedule
90
88
filters []plugins.Filter
91
- scorers []plugins.Scorer
92
- postSchedulePlugins []plugins.PostSchedule
89
+ scorers map [plugins.Scorer ]int // map from scorer to its weight
93
90
picker plugins.Picker
91
+ postSchedulePlugins []plugins.PostSchedule
94
92
}
95
93
96
94
type Datastore interface {
@@ -106,25 +104,22 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
106
104
// 1. Reduce concurrent access to the datastore.
107
105
// 2. Ensure consistent data during the scheduling operation of a request.
108
106
sCtx := types .NewSchedulingContext (ctx , req , types .ToSchedulerPodMetrics (s .datastore .PodGetAll ()))
109
- loggerDebug .Info (fmt .Sprintf ("Scheduling a request. Metrics: %+v" , sCtx .PodsSnapshot ))
107
+ loggerDebug .Info (fmt .Sprintf ("Scheduling a request, Metrics: %+v" , sCtx .PodsSnapshot ))
110
108
111
109
s .runPreSchedulePlugins (sCtx )
112
110
113
111
pods := s .runFilterPlugins (sCtx )
114
112
if len (pods ) == 0 {
115
113
return nil , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : "failed to find a target pod" }
116
114
}
115
+ // if we got here, there is at least one pod to score
116
+ weightedScorePerPod := s .runScorerPlugins (sCtx , pods )
117
117
118
- s . runScorerPlugins (sCtx , pods )
118
+ result := s . runPickerPlugin (sCtx , weightedScorePerPod )
119
119
120
- before := time .Now ()
121
- res := s .picker .Pick (sCtx , pods )
122
- metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
123
- loggerDebug .Info ("After running picker plugins" , "result" , res )
124
-
125
- s .runPostSchedulePlugins (sCtx , res )
120
+ s .runPostSchedulePlugins (sCtx , result )
126
121
127
- return res , nil
122
+ return result , nil
128
123
}
129
124
130
125
func (s * Scheduler ) runPreSchedulePlugins (ctx * types.SchedulingContext ) {
@@ -136,22 +131,6 @@ func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) {
136
131
}
137
132
}
138
133
139
- func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
140
- for _ , plugin := range s .postSchedulePlugins {
141
- ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
142
- before := time .Now ()
143
- plugin .PostSchedule (ctx , res )
144
- << << << < HEAD
145
- << << << < HEAD
146
- metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
147
- }
148
- }
149
- == == == =
150
- metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
151
- }
152
- }
153
-
154
- >> >> >> > b24f948 (scheduler refactoring (#730 ))
155
134
func (s * Scheduler ) runFilterPlugins (ctx * types.SchedulingContext ) []types.Pod {
156
135
loggerDebug := ctx .Logger .V (logutil .DEBUG )
157
136
filteredPods := ctx .PodsSnapshot
@@ -167,70 +146,74 @@ func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod {
167
146
break
168
147
}
169
148
}
149
+ loggerDebug .Info ("After running filter plugins" )
150
+
170
151
return filteredPods
171
152
}
172
153
173
- func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) {
154
+ func (s * Scheduler ) runScorerPlugins (ctx * types.SchedulingContext , pods []types.Pod ) map [types. Pod ] float64 {
174
155
loggerDebug := ctx .Logger .V (logutil .DEBUG )
175
- loggerDebug .Info ("Before running score plugins" , "pods" , pods )
156
+ loggerDebug .Info ("Before running scorer plugins" , "pods" , pods )
157
+
158
+ weightedScorePerPod := make (map [types.Pod ]float64 , len (pods ))
176
159
for _ , pod := range pods {
177
- score := s .runScorersForPod (ctx , pod )
178
- pod .SetScore (score )
160
+ weightedScorePerPod [pod ] = float64 (0 ) // initialize weighted score per pod with 0 value
161
+ }
162
+ // Iterate through each scorer in the chain and accumulate the weighted scores.
163
+ for scorer , weight := range s .scorers {
164
+ loggerDebug .Info ("Running scorer" , "scorer" , scorer .Name ())
165
+ before := time .Now ()
166
+ scores := scorer .Score (ctx , pods )
167
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
168
+ for pod , score := range scores { // weight is relative to the sum of weights
169
+ weightedScorePerPod [pod ] += score * float64 (weight ) // TODO normalize score before multiply with weight
170
+ }
171
+ loggerDebug .Info ("After running scorer" , "scorer" , scorer .Name ())
179
172
}
180
- loggerDebug .Info ("After running score plugins" , "pods" , pods )
173
+ loggerDebug .Info ("After running scorer plugins" )
174
+
175
+ return weightedScorePerPod
181
176
}
182
177
183
- // Iterate through each scorer in the chain and accumulate the scores.
184
- func (s * Scheduler ) runScorersForPod (ctx * types.SchedulingContext , pod types.Pod ) float64 {
185
- logger := ctx .Logger .WithValues ("pod" , pod .GetPod ().NamespacedName ).V (logutil .DEBUG )
186
- score := float64 (0 )
187
- for _ , scorer := range s .scorers {
188
- logger .Info ("Running scorer" , "scorer" , scorer .Name ())
178
+ func (s * Scheduler ) runPickerPlugin (ctx * types.SchedulingContext , weightedScorePerPod map [types.Pod ]float64 ) * types.Result {
179
+ loggerDebug := ctx .Logger .V (logutil .DEBUG )
180
+ scoredPods := make ([]* types.ScoredPod , len (weightedScorePerPod ))
181
+ i := 0
182
+ for pod , score := range weightedScorePerPod {
183
+ scoredPods [i ] = & types.ScoredPod {Pod : pod , Score : score }
184
+ i ++
185
+ }
186
+
187
+ loggerDebug .Info ("Before running picker plugin" , "pods" , weightedScorePerPod )
188
+ before := time .Now ()
189
+ result := s .picker .Pick (ctx , scoredPods )
190
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PickerPluginType , s .picker .Name (), time .Since (before ))
191
+ loggerDebug .Info ("After running picker plugin" , "result" , result )
192
+
193
+ return result
194
+ }
195
+
196
+ func (s * Scheduler ) runPostSchedulePlugins (ctx * types.SchedulingContext , res * types.Result ) {
197
+ for _ , plugin := range s .postSchedulePlugins {
198
+ ctx .Logger .V (logutil .DEBUG ).Info ("Running post-schedule plugin" , "plugin" , plugin .Name ())
189
199
before := time .Now ()
190
- oneScore := scorer .Score (ctx , pod )
191
- metrics .RecordSchedulerPluginProcessingLatency (plugins .ScorerPluginType , scorer .Name (), time .Since (before ))
192
- score += oneScore
193
- logger .Info ("After scorer" , "scorer" , scorer .Name (), "score" , oneScore , "total score" , score )
200
+ plugin .PostSchedule (ctx , res )
201
+ metrics .RecordSchedulerPluginProcessingLatency (plugins .PostSchedulePluginType , plugin .Name (), time .Since (before ))
194
202
}
195
- return score
196
203
}
197
204
198
205
type defaultPlugin struct {
199
206
picker.RandomPicker
200
- << << << < HEAD
201
- == == == =
202
- metrics.RecordSchedulerPluginProcessingLatency (types.PostSchedulePluginType , plugin.Name (), time.Since (before ))
203
- }
204
- }
205
- == == == =
206
- >> >> >> > b24f948 (scheduler refactoring (#730 ))
207
207
}
208
208
209
- << << << < HEAD
210
- func (p * defaultPlugin ) Filter (ctx * types.SchedulingContext , pods []types.Pod ) []types.Pod {
211
- if ctx .Req .Critical {
212
- return lowLatencyFilter .Filter (ctx , pods )
213
- }
209
+ func (p * defaultPlugin ) Name () string {
210
+ return "DefaultPlugin"
211
+ }
214
212
215
- << << << < HEAD
216
- return sheddableRequestFilter .Filter (ctx , pods )
217
- == == == =
218
- func (p * defaultPlugin ) Filter (ctx * types .Context , pods []types .Pod ) ([]types.Pod , error ) {
219
- req := ctx .Req
220
- var filter types.Filter
221
- if req .Critical {
222
- filter = lowLatencyFilter
223
- } else {
224
- filter = sheddableRequestFilter
225
- }
226
- return filter .Filter (ctx , pods )
227
- >> >> >> > 45209 f6 (Refactor scheduler to run plugins (#677 ))
228
- == == == =
229
213
func (p * defaultPlugin ) Filter (ctx * types.SchedulingContext , pods []types.Pod ) []types.Pod {
230
214
if ctx .Req .Critical {
231
215
return lowLatencyFilter .Filter (ctx , pods )
232
216
}
233
217
234
218
return sheddableRequestFilter .Filter (ctx , pods )
235
- >> >> >> > b24f948 (scheduler refactoring (#730 ))
236
219
}
0 commit comments