
Commit d167e49

nirrozenbaum authored and rlakhtakia committed
scheduler refactoring (kubernetes-sigs#730)
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 7792676 commit d167e49

File tree

10 files changed: +214 -285 lines changed


pkg/epp/backend/metrics/pod_metrics.go

+5 -6

@@ -41,9 +41,8 @@ type podMetrics struct {
     ds       Datastore
     interval time.Duration
 
-    parentCtx context.Context
-    once      sync.Once // ensure the StartRefreshLoop is only called once.
-    done      chan struct{}
+    once sync.Once // ensure the StartRefreshLoop is only called once.
+    done chan struct{}
 
     logger logr.Logger
 }
@@ -79,8 +78,8 @@ func toInternalPod(in *corev1.Pod) *Pod {
 }
 
 // start starts a goroutine exactly once to periodically update metrics. The goroutine will be
-// stopped either when stop() is called, or the parentCtx is cancelled.
-func (pm *podMetrics) startRefreshLoop() {
+// stopped either when stop() is called, or the given ctx is cancelled.
+func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
     pm.once.Do(func() {
         go func() {
             pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod())
@@ -90,7 +89,7 @@ func (pm *podMetrics) startRefreshLoop() {
             select {
             case <-pm.done:
                 return
-            case <-pm.parentCtx.Done():
+            case <-ctx.Done():
                 return
             case <-ticker.C: // refresh metrics periodically
                 if err := pm.refreshMetrics(); err != nil {
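Note: taken together with the types.go change below, the stored parentCtx field is gone and the caller's context is threaded directly into startRefreshLoop. A minimal, self-contained sketch of the resulting loop lifecycle (the refresher type, interval, and print below are illustrative stand-ins, not this package's API):

package main

import (
    "context"
    "fmt"
    "time"
)

// refresher mimics the shape of the refactored loop: no context stored on the struct,
// and the loop stops on either the done channel or cancellation of the ctx argument.
type refresher struct {
    interval time.Duration
    done     chan struct{}
}

func (r *refresher) startRefreshLoop(ctx context.Context) {
    go func() {
        ticker := time.NewTicker(r.interval)
        defer ticker.Stop()
        for {
            select {
            case <-r.done: // explicit stop()
                return
            case <-ctx.Done(): // caller cancelled the context passed at start time
                return
            case <-ticker.C:
                fmt.Println("refreshing metrics") // refreshMetrics() in the real code
            }
        }
    }()
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    r := &refresher{interval: 20 * time.Millisecond, done: make(chan struct{})}
    r.startRefreshLoop(ctx)

    time.Sleep(50 * time.Millisecond)
    cancel() // the loop exits without the struct ever holding a context
    time.Sleep(20 * time.Millisecond)
}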

pkg/epp/backend/metrics/types.go

+7 -8

@@ -43,18 +43,17 @@ type PodMetricsFactory struct {
 func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics {
     pod := toInternalPod(in)
     pm := &podMetrics{
-        pmc:       f.pmc,
-        ds:        ds,
-        interval:  f.refreshMetricsInterval,
-        parentCtx: parentCtx,
-        once:      sync.Once{},
-        done:      make(chan struct{}),
-        logger:    log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
+        pmc:      f.pmc,
+        ds:       ds,
+        interval: f.refreshMetricsInterval,
+        once:     sync.Once{},
+        done:     make(chan struct{}),
+        logger:   log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
     }
     pm.pod.Store(pod)
     pm.metrics.Store(newMetrics())
 
-    pm.startRefreshLoop()
+    pm.startRefreshLoop(parentCtx)
     return pm
 }
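Note: for callers, NewPodMetrics keeps its signature; the context is now consumed at the call site instead of being stored on the struct, in line with the usual Go guidance against keeping a Context in a struct field. A hedged sketch of that pattern with illustrative types (not the repository's):

package main

import (
    "context"
    "fmt"
)

// Anti-pattern the refactor removes: a Context retained in a struct field,
// which makes the worker's lifetime and ownership implicit.
type withStoredCtx struct {
    parentCtx context.Context
}

// Pattern the refactor adopts: the struct stays context-free and the caller
// supplies the context to the method that actually runs the work.
type withPassedCtx struct{}

func (w *withPassedCtx) start(ctx context.Context, stopped chan<- struct{}) {
    go func() {
        <-ctx.Done() // lifetime is controlled entirely by the caller
        close(stopped)
    }()
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    stopped := make(chan struct{})

    (&withPassedCtx{}).start(ctx, stopped)
    cancel()
    <-stopped
    fmt.Println("worker stopped by cancelling the caller's context")

    _ = withStoredCtx{parentCtx: context.TODO()} // shown only to contrast with the removed field
}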

pkg/epp/scheduling/plugins/filter.go renamed to pkg/epp/scheduling/plugins/filter/filter.go

+36 -49

@@ -14,56 +14,55 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-package plugins
+package filter
 
 import (
-    "errors"
     "math"
     "math/rand"
     "time"
 
     "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
+    "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
     "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
-    errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
     logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-type Filter struct {
+type baseFilter struct {
     name   string
     filter filterFunc
 }
 
-func (bf *Filter) Name() string {
-    if bf == nil {
+func (f *baseFilter) Name() string {
+    if f == nil {
         return "nil"
     }
-    return bf.name
+    return f.name
 }
 
-func (bf *Filter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+func (f *baseFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
     loggerTrace := ctx.Logger.V(logutil.TRACE)
-    loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods))
+    loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods))
 
-    return bf.filter(ctx, pods)
+    return f.filter(ctx, pods)
 }
 
 // DecisionTreeFilter applies current filterFunc, and then recursively applies next filters
 // depending success or failure of the current filter.
 // It can be used to construct a flow chart algorithm.
 type DecisionTreeFilter struct {
-    Current types.Filter
+    Current plugins.Filter
     // NextOnSuccess filter will be applied after successfully applying the current filter.
     // The filtered results will be passed to the next filter.
-    NextOnSuccess types.Filter
+    NextOnSuccess plugins.Filter
     // NextOnFailure filter will be applied if current filter fails.
     // The original input will be passed to the next filter.
-    NextOnFailure types.Filter
+    NextOnFailure plugins.Filter
     // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
     // success or failure of the current filter.
    // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
     // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
     // NextOnSuccessOrFailure, in the success and failure scenarios, respectively.
-    NextOnSuccessOrFailure types.Filter
+    NextOnSuccessOrFailure plugins.Filter
 }
 
 func (f *DecisionTreeFilter) Name() string {
@@ -73,15 +72,15 @@ func (f *DecisionTreeFilter) Name() string {
     return f.Current.Name()
 }
 
-func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
     loggerTrace := ctx.Logger.V(logutil.TRACE)
-    filtered, err := f.Current.Filter(ctx, pods)
+    filtered := f.Current.Filter(ctx, pods)
 
     next := f.NextOnSuccessOrFailure
-    if err == nil && len(filtered) > 0 {
+    if len(filtered) > 0 {
         if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil {
             // No succeeding filters to run, return.
-            return filtered, err
+            return filtered
         }
         if f.NextOnSuccess != nil {
             next = f.NextOnSuccess
@@ -92,7 +91,7 @@ func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]typ
     } else {
         if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil {
             // No succeeding filters to run, return.
-            return filtered, err
+            return filtered
         }
         if f.NextOnFailure != nil {
             next = f.NextOnFailure
@@ -104,26 +103,24 @@ func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]typ
 }
 
 // filterFunc filters a set of input pods to a subset.
-type filterFunc func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error)
+type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
 
 // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
 func toFilterFunc(pp podPredicate) filterFunc {
-    return func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+    return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
         filtered := []types.Pod{}
         for _, pod := range pods {
             pass := pp(ctx.Req, pod)
             if pass {
                 filtered = append(filtered, pod)
             }
         }
-        if len(filtered) == 0 {
-            return nil, errors.New("no pods left")
-        }
-        return filtered, nil
+
+        return filtered
     }
 }
 
-var LeastQueueFilter = &Filter{
+var LeastQueueFilter = &baseFilter{
     name:   "least queuing",
     filter: leastQueuingFilterFunc,
 }
@@ -135,7 +132,7 @@ var LeastQueueFilter = &Filter{
 // the least one as it gives more choices for the next filter, which on aggregate gave better
 // results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
     min := math.MaxInt
     max := 0
     filtered := []types.Pod{}
@@ -154,15 +151,15 @@ func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod,
             filtered = append(filtered, pod)
         }
     }
-    return filtered, nil
+    return filtered
 }
 
-var LowQueueFilter = &Filter{
+var LowQueueFilter = &baseFilter{
     name:   "low queueing filter",
     filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))),
 }
 
-var LeastKVCacheFilter = &Filter{
+var LeastKVCacheFilter = &baseFilter{
     name:   "least KV cache percent",
     filter: leastKVCacheFilterFunc,
 }
@@ -173,7 +170,7 @@ var LeastKVCacheFilter = &Filter{
 // should consider them all instead of the absolute minimum one. This worked better than picking the
 // least one as it gives more choices for the next filter, which on aggregate gave better results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
     min := math.MaxFloat64
     var max float64 = 0
     filtered := []types.Pod{}
@@ -192,10 +189,10 @@ func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod,
             filtered = append(filtered, pod)
         }
     }
-    return filtered, nil
+    return filtered
 }
 
-var LoRAAffinityFilter = &Filter{
+var LoRAAffinityFilter = &baseFilter{
     name:   "affinity LoRA",
     filter: loRASoftAffinityFilterFunc,
 }
@@ -216,7 +213,7 @@ var LoRAAffinityFilter = &Filter{
 // Returns:
 // - Filtered slice of pod metrics based on affinity and availability
 // - Error if any issues occur during filtering
-func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
+func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 
     // Pre-allocate slices with estimated capacity
     filtered_affinity := make([]types.Pod, 0, len(pods))
@@ -241,34 +238,24 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.P
     // If both groups have pods, use probability to select which group to return
     if len(filtered_affinity) > 0 && len(filtered_available) > 0 {
         if randGen.Float64() < config.Conf.LoraAffinityThreshold {
-            return filtered_affinity, nil
+            return filtered_affinity
         }
-        return filtered_available, nil
+        return filtered_available
     }
 
     // Return whichever group has pods
     if len(filtered_affinity) > 0 {
-        return filtered_affinity, nil
+        return filtered_affinity
     }
 
-    return filtered_available, nil
+    return filtered_available
 }
 
-var HasCapacityFilter = &Filter{
+var HasCapacityFilter = &baseFilter{
     name:   "has capacity for sheddable requests",
     filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))),
 }
 
-var DropRequestFilter = &Filter{
-    name: "drop request",
-    filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
-        ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req)
-        return []types.Pod{}, errutil.Error{
-            Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources",
-        }
-    },
-}
-
 // podPredicate is a filter function to check whether a pod is desired.
 type podPredicate func(req *types.LLMRequest, pod types.Pod) bool
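Note: with the error return removed from filters, a failed filter is simply one that returns an empty slice, so DecisionTreeFilter traversal is driven purely by the size of the result. A hedged sketch of composing the exported filters from the new filter package under the new signature; the tree shape and the nil request below are illustrative, not necessarily the scheduler's actual wiring:

package main

import (
    "context"
    "fmt"

    "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
    "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

func main() {
    // Prefer pods under the low-queue threshold; within those, apply LoRA affinity and then
    // break ties by queue depth and KV-cache usage. If no pod clears the threshold, fall back
    // to the pods with the least queuing overall.
    tree := &filter.DecisionTreeFilter{
        Current: filter.LowQueueFilter,
        NextOnSuccess: &filter.DecisionTreeFilter{
            Current: filter.LoRAAffinityFilter,
            NextOnSuccessOrFailure: &filter.DecisionTreeFilter{
                Current:                filter.LeastQueueFilter,
                NextOnSuccessOrFailure: &filter.DecisionTreeFilter{Current: filter.LeastKVCacheFilter},
            },
        },
        NextOnFailure: &filter.DecisionTreeFilter{Current: filter.LeastQueueFilter},
    }

    // An empty candidate list keeps the example self-contained; the real scheduler builds the
    // SchedulingContext from the incoming LLMRequest and the current pods snapshot, as the
    // tests above do with types.NewSchedulingContext.
    var pods []types.Pod
    ctx := types.NewSchedulingContext(context.Background(), nil, pods)
    fmt.Println("pods left after filtering:", len(tree.Filter(ctx, pods)))
}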

pkg/epp/scheduling/plugins/filter_test.go renamed to pkg/epp/scheduling/plugins/filter/filter_test.go

+13 -25

@@ -14,11 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
-package plugins
+package filter
 
 import (
     "context"
-    "errors"
     "testing"
 
     "github.com/google/go-cmp/cmp"
@@ -34,30 +33,26 @@ func TestFilter(t *testing.T) {
         req    *types.LLMRequest
         input  []types.Pod
         output []types.Pod
-        err    bool
         filter *DecisionTreeFilter
     }{
         {
-            name: "simple filter without successor, failure",
+            name: "simple filter without available pods",
             filter: &DecisionTreeFilter{
-                Current: &Filter{
-                    name: "error",
-                    filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) {
-                        return nil, errors.New("filter error")
+                Current: &baseFilter{
+                    name: "filter all",
+                    filter: func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+                        return []types.Pod{}
                     },
                 },
             },
-            err: true,
+            output: []types.Pod{},
         },
     }
 
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            ctx := types.NewContext(context.Background(), test.req, test.input)
-            got, err := test.filter.Filter(ctx, test.input)
-            if test.err != (err != nil) {
-                t.Errorf("Unexpected error, got %v, want %v", err, test.err)
-            }
+            ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+            got := test.filter.Filter(ctx, test.input)
 
             opt := cmp.AllowUnexported(types.PodMetrics{})
             if diff := cmp.Diff(test.output, got, opt); diff != "" {
@@ -74,7 +69,6 @@ func TestFilterFunc(t *testing.T) {
         req    *types.LLMRequest
         input  []types.Pod
         output []types.Pod
-        err    bool
     }{
         {
             name: "least queuing empty input",
@@ -193,11 +187,8 @@
 
     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            ctx := types.NewContext(context.Background(), test.req, test.input)
-            got, err := test.f(ctx, test.input)
-            if test.err != (err != nil) {
-                t.Errorf("Unexpected error, got %v, want %v", err, test.err)
-            }
+            ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+            got := test.f(ctx, test.input)
 
             opt := cmp.AllowUnexported(types.PodMetrics{})
             if diff := cmp.Diff(test.output, got, opt); diff != "" {
@@ -254,7 +245,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
             },
         },
     }
-    ctx := types.NewContext(context.Background(), req, pods)
+    ctx := types.NewSchedulingContext(context.Background(), req, pods)
 
     // Run the filter function multiple times and count the results
     affinityCount := 0
@@ -265,10 +256,7 @@
     expectedAvailabilityPercent := 100 - expectedAffinityPercent
 
     for i := 0; i < numIterations; i++ {
-        result, err := loRASoftAffinityFilterFunc(ctx, pods)
-        if err != nil {
-            t.Fatalf("Unexpected error: %v", err)
-        }
+        result := loRASoftAffinityFilterFunc(ctx, pods)
 
         // Check which type of pod was returned
         if len(result) != 1 {
