@@ -220,9 +220,17 @@ func TestSchedule(t *testing.T) {
220
220
},
221
221
}
222
222
223
+ schedConfig := & SchedulerConfig {
224
+ preSchedulePlugins : []plugins.PreSchedule {},
225
+ scorers : []plugins.Scorer {},
226
+ filters : []plugins.Filter {defPlugin },
227
+ postSchedulePlugins : []plugins.PostSchedule {},
228
+ picker : defPlugin ,
229
+ }
230
+
223
231
for _ , test := range tests {
224
232
t .Run (test .name , func (t * testing.T ) {
225
- scheduler := NewScheduler (& fakeDataStore {pods : test .input })
233
+ scheduler := NewSchedulerWithConfig (& fakeDataStore {pods : test .input }, schedConfig )
226
234
got , err := scheduler .Schedule (context .Background (), test .req )
227
235
if test .err != (err != nil ) {
228
236
t .Errorf ("Unexpected error, got %v, want %v" , err , test .err )
@@ -257,26 +265,24 @@ func TestSchedulePlugins(t *testing.T) {
257
265
}
258
266
259
267
tests := []struct {
260
- name string
261
- preSchedulePlugins []plugins.PreSchedule
262
- filters []plugins.Filter
263
- scorers []plugins.Scorer
264
- postSchedulePlugins []plugins.PostSchedule
265
- picker plugins.Picker
266
- input []* backendmetrics.FakePodMetrics
267
- wantTargetPod k8stypes.NamespacedName
268
- targetPodScore float64
268
+ name string
269
+ config SchedulerConfig
270
+ input []* backendmetrics.FakePodMetrics
271
+ wantTargetPod k8stypes.NamespacedName
272
+ targetPodScore float64
269
273
// Number of expected pods to score (after filter)
270
274
numPodsToScore int
271
275
err bool
272
276
}{
273
277
{
274
- name : "all plugins executed successfully" ,
275
- preSchedulePlugins : []plugins.PreSchedule {tp1 , tp2 },
276
- filters : []plugins.Filter {tp1 , tp2 },
277
- scorers : []plugins.Scorer {tp1 , tp2 },
278
- postSchedulePlugins : []plugins.PostSchedule {tp1 , tp2 },
279
- picker : pickerPlugin ,
278
+ name : "all plugins executed successfully" ,
279
+ config : SchedulerConfig {
280
+ preSchedulePlugins : []plugins.PreSchedule {tp1 , tp2 },
281
+ filters : []plugins.Filter {tp1 , tp2 },
282
+ scorers : []plugins.Scorer {tp1 , tp2 },
283
+ postSchedulePlugins : []plugins.PostSchedule {tp1 , tp2 },
284
+ picker : pickerPlugin ,
285
+ },
280
286
input : []* backendmetrics.FakePodMetrics {
281
287
{Pod : & backendmetrics.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod1" }}},
282
288
{Pod : & backendmetrics.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod2" }}},
@@ -288,12 +294,14 @@ func TestSchedulePlugins(t *testing.T) {
288
294
err : false ,
289
295
},
290
296
{
291
- name : "filter all" ,
292
- preSchedulePlugins : []plugins.PreSchedule {tp1 , tp2 },
293
- filters : []plugins.Filter {tp1 , tp_filterAll },
294
- scorers : []plugins.Scorer {tp1 , tp2 },
295
- postSchedulePlugins : []plugins.PostSchedule {tp1 , tp2 },
296
- picker : pickerPlugin ,
297
+ name : "filter all" ,
298
+ config : SchedulerConfig {
299
+ preSchedulePlugins : []plugins.PreSchedule {tp1 , tp2 },
300
+ filters : []plugins.Filter {tp1 , tp_filterAll },
301
+ scorers : []plugins.Scorer {tp1 , tp2 },
302
+ postSchedulePlugins : []plugins.PostSchedule {tp1 , tp2 },
303
+ picker : pickerPlugin ,
304
+ },
297
305
input : []* backendmetrics.FakePodMetrics {
298
306
{Pod : & backendmetrics.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod1" }}},
299
307
{Pod : & backendmetrics.Pod {NamespacedName : k8stypes.NamespacedName {Name : "pod2" }}},
@@ -307,29 +315,22 @@ func TestSchedulePlugins(t *testing.T) {
307
315
for _ , test := range tests {
308
316
t .Run (test .name , func (t * testing.T ) {
309
317
// Reset all plugins before each new test case.
310
- for _ , plugin := range test .preSchedulePlugins {
318
+ for _ , plugin := range test .config . preSchedulePlugins {
311
319
plugin .(* TestPlugin ).reset ()
312
320
}
313
- for _ , plugin := range test .postSchedulePlugins {
321
+ for _ , plugin := range test .config . postSchedulePlugins {
314
322
plugin .(* TestPlugin ).reset ()
315
323
}
316
- for _ , plugin := range test .filters {
324
+ for _ , plugin := range test .config . filters {
317
325
plugin .(* TestPlugin ).reset ()
318
326
}
319
- for _ , plugin := range test .scorers {
327
+ for _ , plugin := range test .config . scorers {
320
328
plugin .(* TestPlugin ).reset ()
321
329
}
322
- test .picker .(* TestPlugin ).reset ()
330
+ test .config . picker .(* TestPlugin ).reset ()
323
331
324
332
// Initialize the scheduler
325
- scheduler := & Scheduler {
326
- datastore : & fakeDataStore {pods : test .input },
327
- preSchedulePlugins : test .preSchedulePlugins ,
328
- filters : test .filters ,
329
- scorers : test .scorers ,
330
- postSchedulePlugins : test .postSchedulePlugins ,
331
- picker : test .picker ,
332
- }
333
+ scheduler := NewSchedulerWithConfig (& fakeDataStore {pods : test .input }, & test .config )
333
334
334
335
req := & types.LLMRequest {Model : "test-model" }
335
336
got , err := scheduler .Schedule (context .Background (), req )
@@ -355,35 +356,35 @@ func TestSchedulePlugins(t *testing.T) {
355
356
}
356
357
357
358
// Validate plugin execution counts dynamically
358
- for _ , plugin := range test .preSchedulePlugins {
359
+ for _ , plugin := range test .config . preSchedulePlugins {
359
360
tp , _ := plugin .(* TestPlugin )
360
361
if tp .PreScheduleCallCount != 1 {
361
362
t .Errorf ("Plugin %s PreSchedule() called %d times, expected 1" , tp .NameRes , tp .PreScheduleCallCount )
362
363
}
363
364
}
364
365
365
- for _ , plugin := range test .filters {
366
+ for _ , plugin := range test .config . filters {
366
367
tp , _ := plugin .(* TestPlugin )
367
368
if tp .FilterCallCount != 1 {
368
369
t .Errorf ("Plugin %s Filter() called %d times, expected 1" , tp .NameRes , tp .FilterCallCount )
369
370
}
370
371
}
371
372
372
- for _ , plugin := range test .scorers {
373
+ for _ , plugin := range test .config . scorers {
373
374
tp , _ := plugin .(* TestPlugin )
374
375
if tp .ScoreCallCount != test .numPodsToScore {
375
376
t .Errorf ("Plugin %s Score() called %d times, expected 1" , tp .NameRes , tp .ScoreCallCount )
376
377
}
377
378
}
378
379
379
- for _ , plugin := range test .postSchedulePlugins {
380
+ for _ , plugin := range test .config . postSchedulePlugins {
380
381
tp , _ := plugin .(* TestPlugin )
381
382
if tp .PostScheduleCallCount != 1 {
382
383
t .Errorf ("Plugin %s PostSchedule() called %d times, expected 1" , tp .NameRes , tp .PostScheduleCallCount )
383
384
}
384
385
}
385
386
386
- tp , _ := test .picker .(* TestPlugin )
387
+ tp , _ := test .config . picker .(* TestPlugin )
387
388
if tp .PickCallCount != 1 {
388
389
t .Errorf ("Picker plugin %s Pick() called %d times, expected 1" , tp .NameRes , tp .PickCallCount )
389
390
}
0 commit comments