Commit 7de96f5

configurable filter chains
Signed-off-by: Kuromesi <[email protected]>
1 parent 38cddf0 commit 7de96f5

11 files changed, +804 −41 lines changed

pkg/ext-proc/backend/datastore.go (+12)

@@ -29,10 +29,22 @@ type K8sDatastore struct {
 	inferencePool   *v1alpha1.InferencePool
 	InferenceModels *sync.Map
 	pods            *sync.Map
+
+	filterConfigMap *corev1.ConfigMap
 }

 type K8sDatastoreOption func(*K8sDatastore)

+func (ds *K8sDatastore) GetFilterConfigMap() *corev1.ConfigMap {
+	return ds.filterConfigMap
+}
+
+func WithFilterConfigMap(filterConfigMap *corev1.ConfigMap) K8sDatastoreOption {
+	return func(store *K8sDatastore) {
+		store.filterConfigMap = filterConfigMap
+	}
+}
+
 // WithPods can be used in tests to override the pods.
 func WithPods(pods []*PodMetrics) K8sDatastoreOption {
 	return func(store *K8sDatastore) {
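The new option follows the existing K8sDatastoreOption pattern. A minimal sketch of how it could be exercised from a test inside package backend; the test below is illustrative and not part of this commit:

package backend

import (
	"testing"

	corev1 "k8s.io/api/core/v1"
)

// Illustrative only: checks that the option stores the ConfigMap that the getter returns.
func TestWithFilterConfigMapSketch(t *testing.T) {
	seed := &corev1.ConfigMap{Data: map[string]string{"filter": "{}"}}
	ds := &K8sDatastore{}
	WithFilterConfigMap(seed)(ds)
	if ds.GetFilterConfigMap() != seed {
		t.Fatal("expected GetFilterConfigMap to return the seeded ConfigMap")
	}
}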
New file: FilterConfigReconciler, package backend (+35)

@@ -0,0 +1,35 @@
+package backend
+
+import (
+	"context"
+
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/klog/v2"
+	ctrl "sigs.k8s.io/controller-runtime"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+type FilterConfigReconciler struct {
+	client.Client
+	Datastore *K8sDatastore
+}
+
+func (c *FilterConfigReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+	if req.NamespacedName.Name != "filter-config" || req.NamespacedName.Namespace != "default" {
+		return ctrl.Result{}, nil
+	}
+	cm := &corev1.ConfigMap{}
+	if err := c.Get(ctx, req.NamespacedName, cm); err != nil {
+		klog.Errorf("unable to get ConfigMap, err: %v", err)
+		return ctrl.Result{}, err
+	}
+	klog.Infof("updating filter config to: %++v", cm.Data)
+	c.Datastore.filterConfigMap = cm.DeepCopy()
+	return ctrl.Result{}, nil
+}
+
+func (c *FilterConfigReconciler) SetupWithManager(mgr ctrl.Manager) error {
+	return ctrl.NewControllerManagedBy(mgr).
+		For(&corev1.ConfigMap{}).
+		Complete(c)
+}
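As written, the reconciler watches every ConfigMap in the cluster and filters by name and namespace inside Reconcile. A possible refinement, sketched below and not part of this commit, is to push that check into an event filter so unrelated ConfigMaps never trigger a reconcile; the helper name is hypothetical:

package backend

import (
	corev1 "k8s.io/api/core/v1"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
)

// Sketch: same name/namespace check as in Reconcile, expressed as an event filter.
func (c *FilterConfigReconciler) setupWithManagerFiltered(mgr ctrl.Manager) error {
	onlyFilterConfig := predicate.NewPredicateFuncs(func(obj client.Object) bool {
		return obj.GetName() == "filter-config" && obj.GetNamespace() == "default"
	})
	return ctrl.NewControllerManagedBy(mgr).
		For(&corev1.ConfigMap{}).
		WithEventFilter(onlyFilterConfig).
		Complete(c)
}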

pkg/ext-proc/main.go (+11 −1)

@@ -146,6 +146,13 @@ func main() {
 		klog.Error(err, "Error setting up EndpointSliceReconciler")
 	}

+	if err := (&backend.FilterConfigReconciler{
+		Datastore: datastore,
+		Client:    mgr.GetClient(),
+	}).SetupWithManager(mgr); err != nil {
+		klog.Error(err, "Error setting up FilterConfigReconciler")
+	}
+
 	errChan := make(chan error)
 	go func() {
 		if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
@@ -160,11 +167,14 @@ func main() {
 	if err := pp.Init(*refreshPodsInterval, *refreshMetricsInterval); err != nil {
 		klog.Fatalf("failed to initialize: %v", err)
 	}
+
+	orchestrator := scheduling.NewFilterOrchestrator(datastore)
+
 	extProcPb.RegisterExternalProcessorServer(
 		s,
 		handlers.NewServer(
 			pp,
-			scheduling.NewScheduler(pp),
+			scheduling.NewScheduler(pp, scheduling.WithOrchestrator(orchestrator)),
 			*targetPodHeader,
 			datastore))
 	healthPb.RegisterHealthServer(s, &healthServer{})
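The scheduler side of this wiring, NewScheduler accepting variadic options and WithOrchestrator, lives in files of this commit that are not shown on this page. For readers unfamiliar with the pattern, a standalone sketch of variadic functional options of this kind; every name in it is hypothetical and none of them exist in the repository:

package main

import "fmt"

// Hypothetical illustration of the functional-option pattern implied by
// scheduling.NewScheduler(pp, scheduling.WithOrchestrator(orchestrator)).
type orchestrator interface{ Name() string }

type scheduler struct {
	orch orchestrator
}

type schedulerOption func(*scheduler)

func withOrchestrator(o orchestrator) schedulerOption {
	return func(s *scheduler) { s.orch = o }
}

func newScheduler(opts ...schedulerOption) *scheduler {
	s := &scheduler{}
	for _, opt := range opts {
		opt(s)
	}
	return s
}

type noopOrchestrator struct{}

func (noopOrchestrator) Name() string { return "noop" }

func main() {
	s := newScheduler(withOrchestrator(noopOrchestrator{}))
	fmt.Println(s.orch.Name()) // prints "noop"
}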
New file: TopK heap, package algorithms (+52)

@@ -0,0 +1,52 @@
+package algorithms
+
+import "container/heap"
+
+type TopK[T any] interface {
+	TopK(elems []T, k int) []T
+}
+
+type HeapTopKImpl[T any] struct {
+	cmp    func(a, b T) bool
+	sorted []T
+}
+
+func NewHeapTopK[T any](cmp func(a, b T) bool) TopK[T] {
+	return &HeapTopKImpl[T]{
+		cmp: cmp,
+	}
+}
+
+func (h *HeapTopKImpl[T]) TopK(elems []T, k int) []T {
+	if k <= 0 {
+		return []T{}
+	}
+
+	if k >= len(elems) {
+		return elems
+	}
+
+	h.sorted = []T{}
+	heap.Init(h)
+	for _, e := range elems {
+		heap.Push(h, e)
+		if h.Len() > k {
+			heap.Pop(h)
+		}
+	}
+	return h.sorted
+}
+
+func (h *HeapTopKImpl[T]) Len() int           { return len(h.sorted) }
+func (h *HeapTopKImpl[T]) Less(i, j int) bool { return h.cmp(h.sorted[i], h.sorted[j]) }
+func (h *HeapTopKImpl[T]) Swap(i, j int)      { h.sorted[i], h.sorted[j] = h.sorted[j], h.sorted[i] }
+
+func (h *HeapTopKImpl[T]) Push(x any) {
+	h.sorted = append(h.sorted, x.(T))
+}
+
+func (h *HeapTopKImpl[T]) Pop() any {
+	pop := h.sorted[len(h.sorted)-1]
+	h.sorted = h.sorted[:len(h.sorted)-1]
+	return pop
+}
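TopK keeps a bounded heap of size k: every element is pushed, and once the heap exceeds k the root, the worst element under cmp, is evicted. With a "<" comparator the survivors are the k largest values, returned in heap order rather than sorted order. An illustrative example, not part of the commit, that would sit in package algorithms:

package algorithms

import "fmt"

// Keep the 2 highest scores: the min-heap comparator evicts the smallest
// element each time the heap grows past k.
func ExampleNewHeapTopK() {
	topK := NewHeapTopK(func(a, b float64) bool { return a < b })
	fmt.Println(topK.TopK([]float64{0.9, 0.1, 0.5, 0.7}, 2))
	// Output: [0.7 0.9]
}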
New file: TopK tests, package algorithms (+49)

@@ -0,0 +1,49 @@
+package algorithms
+
+import (
+	"reflect"
+	"testing"
+)
+
+func TestMaxTopK(t *testing.T) {
+	tests := []struct {
+		elems []int
+		k     int
+		want  []int
+	}{
+		{[]int{1, 2, 3, 4, 5}, 3, []int{3, 4, 5}},
+		{[]int{5, 4, 3, 2, 1}, 2, []int{4, 5}},
+		{[]int{1, 3, 5, 7, 9}, 0, []int{}},
+		{[]int{}, 3, []int{}},
+		{[]int{10}, 1, []int{10}},
+		{[]int{1, 2, 3}, 5, []int{1, 2, 3}},
+	}
+
+	for _, tt := range tests {
+		h := &HeapTopKImpl[int]{cmp: func(a, b int) bool { return a < b }}
+		got := h.TopK(tt.elems, tt.k)
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("TopK(%v, %d) = %v, want %v", tt.elems, tt.k, got, tt.want)
+		}
+	}
+}
+
+func TestMinTopK(t *testing.T) {
+	tests := []struct {
+		elems []int
+		k     int
+		want  []int
+	}{
+		{[]int{1, 2, 3, 4, 5}, 3, []int{3, 1, 2}},
+		{[]int{5, 4, 3, 2, 1}, 2, []int{2, 1}},
+		{[]int{1, 2, 3}, 5, []int{1, 2, 3}},
+	}
+
+	for _, tt := range tests {
+		h := &HeapTopKImpl[int]{cmp: func(a, b int) bool { return a > b }}
+		got := h.TopK(tt.elems, tt.k)
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("TopK(%v, %d) = %v, want %v", tt.elems, tt.k, got, tt.want)
+		}
+	}
+}

pkg/ext-proc/scheduling/filter.go (+22 −13)

@@ -4,43 +4,46 @@ import (
 	"errors"
 	"math"

+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
 	"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
+
 	klog "k8s.io/klog/v2"
 )

-type Filter interface {
+type FilterChain interface {
 	Name() string
 	Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
 }

-// filter applies current filterFunc, and then recursively applies next filters depending success or
+// filterChainImpl applies current filterFunc, and then recursively applies next filters depending success or
 // failure of the current filterFunc.
 // It can be used to construct a flow chart algorithm.
-type filter struct {
+type filterChainImpl struct {
 	name   string
-	filter filterFunc
+	filter filter
 	// nextOnSuccess filter will be applied after successfully applying the current filter.
 	// The filtered results will be passed to the next filter.
-	nextOnSuccess *filter
+	nextOnSuccess *filterChainImpl
 	// nextOnFailure filter will be applied if current filter fails.
 	// The original input will be passed to the next filter.
-	nextOnFailure *filter
+	nextOnFailure *filterChainImpl
 	// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
 	// success or failure of the current filter.
 	// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
 	// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
 	// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
-	nextOnSuccessOrFailure *filter
+	nextOnSuccessOrFailure *filterChainImpl
 }

-func (f *filter) Name() string {
+func (f *filterChainImpl) Name() string {
 	if f == nil {
 		return "nil"
 	}
 	return f.name
 }

-func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+func (f *filterChainImpl) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 	klog.V(3).Infof("Running filter %q on request %v with %v pods", f.name, req, len(pods))

 	filtered, err := f.filter(req, pods)
@@ -71,11 +74,11 @@ func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend
 	}
 }

-// filterFunc filters a set of input pods to a subset.
-type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
+// filter filters a set of input pods to a subset.
+type filter func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)

-// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
-func toFilterFunc(pp podPredicate) filterFunc {
+// toFilter is a helper function to convert a per pod filter func to the FilterFunc.
+func toFilter(pp podPredicate) filter {
 	return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 		filtered := []*backend.PodMetrics{}
 		for _, pod := range pods {
@@ -152,6 +155,12 @@ func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*bac
 	return filtered, nil
 }

+func dropRequestFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+	klog.Infof("Dropping request %v", req)
+	return []*backend.PodMetrics{}, status.Errorf(
+		codes.ResourceExhausted, "dropping request due to limited backend resources")
+}
+
 // podPredicate is a filter function to check whether a pod is desired.
 type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
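The renamed filterChainImpl keeps the flow-chart semantics described in its comments: nextOnSuccess receives the filtered pods, while nextOnFailure receives the original input. A sketch, not part of this commit, of how a small chain could be wired inside package scheduling using helpers that appear elsewhere in this diff; the 0 and 0.8 thresholds are illustrative only:

package scheduling

// Sketch: if the capacity predicate leaves at least one pod, continue to the
// least-KV-cache filter; otherwise drop the request via dropRequestFilterFunc.
var exampleChain FilterChain = &filterChainImpl{
	name:   "has capacity",
	filter: toFilter(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
	nextOnSuccess: &filterChainImpl{
		name:   "least KV cache",
		filter: leastKVCacheFilterFunc,
	},
	nextOnFailure: &filterChainImpl{
		name:   "drop request",
		filter: dropRequestFilterFunc,
	},
}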

pkg/ext-proc/scheduling/filter_test.go (+5 −5)

@@ -15,11 +15,11 @@ func TestFilter(t *testing.T) {
 		input  []*backend.PodMetrics
 		output []*backend.PodMetrics
 		err    bool
-		filter *filter
+		filter *filterChainImpl
 	}{
 		{
 			name: "simple filter without successor, failure",
-			filter: &filter{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
+			filter: &filterChainImpl{filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
 				return nil, errors.New("filter error")
 			}},
 			err: true,
@@ -216,7 +216,7 @@ func TestFilter(t *testing.T) {
 func TestFilterFunc(t *testing.T) {
 	tests := []struct {
 		name   string
-		f      filterFunc
+		f      filter
 		req    *LLMRequest
 		input  []*backend.PodMetrics
 		output []*backend.PodMetrics
@@ -302,7 +302,7 @@ func TestFilterFunc(t *testing.T) {
 		},
 		{
 			name: "noQueueAndLessThanKVCacheThresholdPredicate",
-			f:    toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
+			f:    toFilter(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)),
 			input: []*backend.PodMetrics{
 				{
 					// This pod should be returned.
@@ -337,7 +337,7 @@ func TestFilterFunc(t *testing.T) {
 		},
 		{
 			name: "low LoRA cost",
-			f:    toFilterFunc(lowLoRACostPredicate),
+			f:    toFilter(lowLoRACostPredicate),
 			req: &LLMRequest{
 				Model:               "model",
 				ResolvedTargetModel: "model",
