diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index c85d4d794..7339389ad 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -41,9 +41,8 @@ type podMetrics struct { ds Datastore interval time.Duration - parentCtx context.Context - once sync.Once // ensure the StartRefreshLoop is only called once. - done chan struct{} + once sync.Once // ensure the StartRefreshLoop is only called once. + done chan struct{} logger logr.Logger } @@ -79,8 +78,8 @@ func toInternalPod(in *corev1.Pod) *Pod { } // start starts a goroutine exactly once to periodically update metrics. The goroutine will be -// stopped either when stop() is called, or the parentCtx is cancelled. -func (pm *podMetrics) startRefreshLoop() { +// stopped either when stop() is called, or the given ctx is cancelled. +func (pm *podMetrics) startRefreshLoop(ctx context.Context) { pm.once.Do(func() { go func() { pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod()) @@ -90,7 +89,7 @@ func (pm *podMetrics) startRefreshLoop() { select { case <-pm.done: return - case <-pm.parentCtx.Done(): + case <-ctx.Done(): return case <-ticker.C: // refresh metrics periodically if err := pm.refreshMetrics(); err != nil { diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 925a0cc5a..20d42ae4b 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -43,18 +43,17 @@ type PodMetricsFactory struct { func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics { pod := toInternalPod(in) pm := &podMetrics{ - pmc: f.pmc, - ds: ds, - interval: f.refreshMetricsInterval, - parentCtx: parentCtx, - once: sync.Once{}, - done: make(chan struct{}), - logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), + pmc: f.pmc, + ds: ds, + interval: f.refreshMetricsInterval, + once: sync.Once{}, + done: make(chan struct{}), + logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), } pm.pod.Store(pod) pm.metrics.Store(newMetrics()) - pm.startRefreshLoop() + pm.startRefreshLoop(parentCtx) return pm } @@ -79,6 +78,10 @@ func (p *Pod) String() string { } func (p *Pod) Clone() *Pod { + if p == nil { + return nil + } + return &Pod{ NamespacedName: types.NamespacedName{ Name: p.NamespacedName.Name, @@ -118,6 +121,10 @@ func (m *Metrics) String() string { } func (m *Metrics) Clone() *Metrics { + if m == nil { + return nil + } + cm := make(map[string]int, len(m.ActiveModels)) for k, v := range m.ActiveModels { cm[k] = v diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 44537923d..9121b59af 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -67,7 +67,7 @@ func (s *StreamingServer) HandleRequestBody( ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, } - logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical) + logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) var err error // Update target models in the body. 
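The pod_metrics.go hunk above drops the stored parentCtx field and instead has the caller pass its context into startRefreshLoop, so the loop exits on either stop() or context cancellation. A minimal standalone sketch of that lifecycle, using illustrative names rather than the actual epp types, might look like this:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// refresher mimics the podMetrics lifecycle in the patch: the loop is started
// at most once and exits when either stop() is called or the caller's context
// is cancelled. Names here are illustrative stand-ins, not the epp types.
type refresher struct {
	once sync.Once
	done chan struct{}
}

func (r *refresher) start(ctx context.Context, interval time.Duration, refresh func() error) {
	r.once.Do(func() {
		go func() {
			ticker := time.NewTicker(interval)
			defer ticker.Stop()
			for {
				select {
				case <-r.done: // explicit stop()
					return
				case <-ctx.Done(): // caller's context cancelled
					return
				case <-ticker.C: // refresh metrics periodically
					if err := refresh(); err != nil {
						fmt.Println("refresh failed:", err)
					}
				}
			}
		}()
	})
}

func (r *refresher) stop() { close(r.done) }

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Millisecond)
	defer cancel()

	r := &refresher{done: make(chan struct{})}
	r.start(ctx, 30*time.Millisecond, func() error {
		fmt.Println("refreshing metrics")
		return nil
	})

	<-ctx.Done() // the loop exits via ctx, even though stop() was never called
	time.Sleep(10 * time.Millisecond)
}
```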
@@ -81,11 +81,11 @@ func (s *StreamingServer) HandleRequestBody( return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} } - target, err := s.scheduler.Schedule(ctx, llmReq) + res, err := s.scheduler.Schedule(ctx, llmReq) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } - targetPod := target.GetPod() + targetPod := res.TargetPod.GetPod() // Insert target endpoint to instruct Envoy to route requests to the specified target pod. // Attach the port number @@ -96,8 +96,7 @@ func (s *StreamingServer) HandleRequestBody( endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics", - fmt.Sprintf("%+v", target)) + "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 7bb0fcb16..e7ecc26dc 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -65,7 +65,7 @@ type StreamingServer struct { } type Scheduler interface { - Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error) + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) } // RequestContext stores context information during the life time of an HTTP request. diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/filter.go index 99044e976..a9fd2f969 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/filter.go @@ -17,20 +17,15 @@ limitations under the License. package scheduling import ( - "errors" "math" "math/rand" "time" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -type Filter interface { - Name() string - Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) -} - type basicFilter struct { name string filter filterFunc @@ -43,7 +38,7 @@ func (bf *basicFilter) Name() string { return bf.name } -func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (bf *basicFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { loggerTrace := ctx.Logger.V(logutil.TRACE) loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods)) @@ -54,19 +49,19 @@ func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]* // depending success or failure of the current filter. // It can be used to construct a flow chart algorithm. type decisionTreeFilter struct { - current Filter + current plugins.Filter // nextOnSuccess filter will be applied after successfully applying the current filter. // The filtered results will be passed to the next filter. - nextOnSuccess Filter + nextOnSuccess plugins.Filter // nextOnFailure filter will be applied if current filter fails. // The original input will be passed to the next filter. 
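The handlers/server.go change above switches the Scheduler contract from returning a bare pod to returning a *schedulingtypes.Result, which the request handler then unwraps via res.TargetPod.GetPod(). Below is a small, self-contained sketch of a caller consuming that shape; every type name here is a simplified stand-in, not the actual epp package types:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"strconv"
)

// Illustrative stand-ins for the patch's scheduling and backend metrics types.
type backendPod struct{ Address string }

type pod interface{ GetPod() *backendPod }

type result struct{ TargetPod pod }

type scheduler interface {
	Schedule(ctx context.Context, model string) (*result, error)
}

type staticPod struct{ p *backendPod }

func (s *staticPod) GetPod() *backendPod { return s.p }

// firstPodScheduler is a toy scheduler that always returns the first candidate,
// wrapped in a result the way the patched Schedule does.
type firstPodScheduler struct{ pods []pod }

func (s *firstPodScheduler) Schedule(context.Context, string) (*result, error) {
	if len(s.pods) == 0 {
		return nil, errors.New("no candidate pods")
	}
	return &result{TargetPod: s.pods[0]}, nil
}

func main() {
	var s scheduler = &firstPodScheduler{
		pods: []pod{&staticPod{p: &backendPod{Address: "10.0.0.7"}}},
	}
	res, err := s.Schedule(context.Background(), "tweet-summary")
	if err != nil {
		fmt.Println("failed to find target pod:", err)
		return
	}
	// Mirrors the handler change: unwrap res.TargetPod before building the endpoint.
	endpoint := res.TargetPod.GetPod().Address + ":" + strconv.Itoa(8000)
	fmt.Println("routing to", endpoint)
}
```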
- nextOnFailure Filter + nextOnFailure plugins.Filter // nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the // success or failure of the current filter. // NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of // nextOnSuccessOrFailure, in the success and failure scenarios, respectively. - nextOnSuccessOrFailure Filter + nextOnSuccessOrFailure plugins.Filter } func (f *decisionTreeFilter) Name() string { @@ -76,15 +71,15 @@ func (f *decisionTreeFilter) Name() string { return f.current.Name() } -func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (f *decisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { loggerTrace := ctx.Logger.V(logutil.TRACE) - filtered, err := f.current.Filter(ctx, pods) + filtered := f.current.Filter(ctx, pods) next := f.nextOnSuccessOrFailure - if err == nil && len(filtered) > 0 { + if len(filtered) > 0 { if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } if f.nextOnSuccess != nil { next = f.nextOnSuccess @@ -95,7 +90,7 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics } else { if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } if f.nextOnFailure != nil { next = f.nextOnFailure @@ -107,22 +102,20 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics } // filterFunc filters a set of input pods to a subset. -type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) +type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - filtered := []*types.PodMetrics{} + return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filtered := []types.Pod{} for _, pod := range pods { pass := pp(ctx.Req, pod) if pass { filtered = append(filtered, pod) } } - if len(filtered) == 0 { - return nil, errors.New("no pods left") - } - return filtered, nil + + return filtered } } @@ -138,26 +131,26 @@ var leastQueueFilter = &basicFilter{ // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
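To make the band selection in leastQueuingFilterFunc concrete: with the queue sizes used in the "least queuing" unit test below (0, 3, 10), min is 0 and max is 10, so any pod with WaitingQueueSize <= 0 + (10-0)/3 = 3 survives. A simplified standalone sketch of the same arithmetic (plain ints instead of types.Pod and GetMetrics()):

```go
package main

import (
	"fmt"
	"math"
)

// leastQueuing keeps every pod whose queue size falls into the first
// 1/len(queues) slice of the [min, max] range, mirroring leastQueuingFilterFunc.
func leastQueuing(queues []int) []int {
	min, max := math.MaxInt, 0
	for _, q := range queues {
		if q < min {
			min = q
		}
		if q > max {
			max = q
		}
	}
	kept := []int{}
	for _, q := range queues {
		if q <= min+(max-min)/len(queues) {
			kept = append(kept, q)
		}
	}
	return kept
}

func main() {
	// Same data as the "least queuing" test case: min=0, max=10, band = 0 + 10/3 = 3.
	fmt.Println(leastQueuing([]int{0, 3, 10})) // [0 3]
}
```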
-func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxInt max := 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.WaitingQueueSize <= min { - min = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize <= min { + min = pod.GetMetrics().WaitingQueueSize } - if pod.WaitingQueueSize >= max { - max = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize >= max { + max = pod.GetMetrics().WaitingQueueSize } } for _, pod := range pods { - if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) { + if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { filtered = append(filtered, pod) } } - return filtered, nil + return filtered } var lowQueueFilter = &basicFilter{ @@ -176,26 +169,26 @@ var leastKVCacheFilter = &basicFilter{ // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxFloat64 var max float64 = 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.KVCacheUsagePercent <= min { - min = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent <= min { + min = pod.GetMetrics().KVCacheUsagePercent } - if pod.KVCacheUsagePercent >= max { - max = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent >= max { + max = pod.GetMetrics().KVCacheUsagePercent } } for _, pod := range pods { - if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { filtered = append(filtered, pod) } } - return filtered, nil + return filtered } var loRAAffinityFilter = &basicFilter{ @@ -219,20 +212,20 @@ var loRAAffinityFilter = &basicFilter{ // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]*types.PodMetrics, 0, len(pods)) - filtered_available := make([]*types.PodMetrics, 0, len(pods)) + filtered_affinity := make([]types.Pod, 0, len(pods)) + filtered_available := make([]types.Pod, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { - _, active := pod.ActiveModels[ctx.Req.ResolvedTargetModel] - _, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel] + _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] if active || waiting { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels { + } else if 
len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -244,36 +237,36 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([ // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { if randGen.Float64() < config.LoraAffinityThreshold { - return filtered_affinity, nil + return filtered_affinity } - return filtered_available, nil + return filtered_available } // Return whichever group has pods if len(filtered_affinity) > 0 { - return filtered_affinity, nil + return filtered_affinity } - return filtered_available, nil + return filtered_available } // podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool +type podPredicate func(req *types.LLMRequest, pod types.Pod) bool func queueThresholdPredicate(queueThreshold int) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.WaitingQueueSize <= queueThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold } } func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.KVCacheUsagePercent <= kvCacheThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold } } func (pp podPredicate) and(another podPredicate) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return func(req *types.LLMRequest, pod types.Pod) bool { return pp(req, pod) && another(req, pod) } } diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/filter_test.go index 543826d06..cec3d4905 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/filter_test.go @@ -18,7 +18,6 @@ package scheduling import ( "context" - "errors" "testing" "github.com/google/go-cmp/cmp" @@ -31,32 +30,28 @@ func TestFilter(t *testing.T) { tests := []struct { name string req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics - err bool + input []types.Pod + output []types.Pod filter *decisionTreeFilter }{ { name: "simple filter without successor, failure", filter: &decisionTreeFilter{ current: &basicFilter{ - name: "error", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - return nil, errors.New("filter error") + name: "filter all pods", + filter: func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + return []types.Pod{} }, }, }, - err: true, + output: []types.Pod{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewContext(context.Background(), test.req, test.input) - got, err := test.filter.Filter(ctx, test.input) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.filter.Filter(ctx, test.input) if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) @@ -70,43 +65,42 @@ func TestFilterFunc(t *testing.T) { name string f filterFunc req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics - err bool + input []types.Pod + 
output []types.Pod }{ { name: "least queuing empty input", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least queuing", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, @@ -116,36 +110,36 @@ func TestFilterFunc(t *testing.T) { { name: "least kv cache empty input", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least kv cache", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 1.0, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, @@ -155,22 +149,22 @@ func TestFilterFunc(t *testing.T) { { name: "lowQueueAndLessThanKVCacheThresholdPredicate", f: toFilterFunc(queueThresholdPredicate(0).and(kvCacheThresholdPredicate(0.8))), - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ // This pod should be returned. Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ // Queue is non zero, despite low kv cache, should not return. 
Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ // High kv cache despite zero queue, should not return Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -178,8 +172,8 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -191,13 +185,11 @@ func TestFilterFunc(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewContext(context.Background(), test.req, test.input) - got, err := test.f(ctx, test.input) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.f(ctx, test.input) - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -233,8 +225,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { } // Test setup: One affinity pod and one available pod - pods := []*types.PodMetrics{ - { + pods := []types.Pod{ + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "affinity-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -243,7 +235,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, }, - { + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "available-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -251,7 +243,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, } - ctx := types.NewContext(context.Background(), req, pods) + ctx := types.NewSchedulingContext(context.Background(), req, pods) // Run the filter function multiple times and count the results affinityCount := 0 @@ -262,10 +254,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { expectedAvailabilityPercent := 100 - expectedAffinityPercent for i := 0; i < numIterations; i++ { - result, err := loRASoftAffinityFilterFunc(ctx, pods) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + result := loRASoftAffinityFilterFunc(ctx, pods) // Check which type of pod was returned if len(result) != 1 { diff --git a/pkg/epp/scheduling/plugins/plugins.go b/pkg/epp/scheduling/plugins/plugins.go new file mode 100644 index 000000000..689785582 --- /dev/null +++ b/pkg/epp/scheduling/plugins/plugins.go @@ -0,0 +1,60 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// Plugin defines the interface for scheduler plugins, combining scoring, filtering, +// and event handling capabilities. +type Plugin interface { + // Name returns the name of the plugin. 
+ Name() string +} + +// PreSchedule is called when the scheduler receives a new request. It can be used for various +// initialization work. +type PreSchedule interface { + Plugin + PreSchedule(ctx *types.SchedulingContext) +} + +// Filter defines the interface for filtering a list of pods based on context. +type Filter interface { + Plugin + Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod +} + +// Scorer defines the interface for scoring pods based on context. +type Scorer interface { + Plugin + Score(ctx *types.SchedulingContext, pod types.Pod) (float64, error) +} + +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { + Plugin + PostSchedule(ctx *types.SchedulingContext, res *types.Result) +} + +// PostResponse is called by the scheduler after a successful response was sent. +// The given pod argument is the pod that served the request. +type PostResponse interface { + Plugin + PostResponse(ctx *types.SchedulingContext, pod types.Pod) +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 8679ffbad..d841e9c7f 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -24,9 +24,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -95,24 +95,13 @@ var ( current: hasCapacityFilter, nextOnSuccess: lowLatencyFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable - // request to make room for critical requests. - nextOnFailure: dropRequestFilter, + // request to make room for critical requests. for this, we don't define nextOnFailure. } hasCapacityFilter = &basicFilter{ name: "has capacity for sheddable requests", filter: toFilterFunc(queueThresholdPredicate(config.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.KVCacheThreshold))), } - - dropRequestFilter = &basicFilter{ - name: "drop request", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) - return []*types.PodMetrics{}, errutil.Error{ - Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", - } - }, - } ) func NewScheduler(datastore Datastore) *Scheduler { @@ -125,8 +114,8 @@ func NewScheduler(datastore Datastore) *Scheduler { type Scheduler struct { datastore Datastore - criticalRequestFilter Filter - sheddableRequestFilter Filter + criticalRequestFilter plugins.Filter + sheddableRequestFilter plugins.Filter } type Datastore interface { @@ -134,27 +123,27 @@ type Datastore interface { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (targetPod types.Pod, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (res *types.Result, err error) { logger := log.FromContext(ctx).WithValues("request", req) // Snapshot pod metrics from the datastore to: // 1. 
Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) - var filter Filter + var filter plugins.Filter if req.Critical { filter = s.criticalRequestFilter } else { filter = s.sheddableRequestFilter } - pods, err := filter.Filter(sCtx, sCtx.PodsSnapshot) - if err != nil || len(pods) == 0 { + pods := filter.Filter(sCtx, sCtx.PodsSnapshot) + if len(pods) == 0 { return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) } logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) i := rand.Intn(len(pods)) - return pods[i], nil + return &types.Result{TargetPod: pods[i]}, nil } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 3fd3fb244..04a589ba7 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -31,7 +31,7 @@ func TestSchedule(t *testing.T) { name string req *types.LLMRequest input []*backendmetrics.FakePodMetrics - output types.Pod + output *types.Result err bool }{ { @@ -80,17 +80,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, + output: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.1, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -139,17 +141,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + output: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -212,7 +216,8 @@ func TestSchedule(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 9450652ed..f22d9ee39 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -35,20 +35,26 @@ type LLMRequest struct { Critical bool } -// Context holds contextual information during a 
scheduling operation. -type Context struct { - context.Context - Logger logr.Logger - Req *LLMRequest - PodsSnapshot []*PodMetrics +func (r *LLMRequest) String() string { + return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical) } type Pod interface { GetPod() *backendmetrics.Pod GetMetrics() *backendmetrics.Metrics + SetScore(float64) + Score() float64 String() string } +// SchedulingContext holds contextual information during a scheduling operation. +type SchedulingContext struct { + context.Context + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod +} + func (pm *PodMetrics) String() string { if pm == nil { return "" @@ -64,14 +70,23 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { return pm.Metrics } +func (pm *PodMetrics) SetScore(score float64) { + pm.score = score +} + +func (pm *PodMetrics) Score() float64 { + return pm.score +} + type PodMetrics struct { + score float64 *backendmetrics.Pod *backendmetrics.Metrics } -func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Context { +func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) - return &Context{ + return &SchedulingContext{ Context: ctx, Logger: logger, Req: req, @@ -79,10 +94,15 @@ func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Conte } } -func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []*PodMetrics { - pm := make([]*PodMetrics, 0, len(pods)) - for _, pod := range pods { - pm = append(pm, &PodMetrics{pod.GetPod().Clone(), pod.GetMetrics().Clone()}) +func ToSchedulerPodMetrics(podsMetrics []backendmetrics.PodMetrics) []Pod { + pods := make([]Pod, 0, len(podsMetrics)) + for _, pod := range podsMetrics { + pods = append(pods, &PodMetrics{Pod: pod.GetPod().Clone(), Metrics: pod.GetMetrics().Clone()}) } - return pm + return pods +} + +// Result captures the scheduler result. +type Result struct { + TargetPod Pod }
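Tying the new plugins.Scorer interface to the score plumbing added on types.Pod, a hypothetical scorer could look like the sketch below. The package name and scorer are illustrative only; this patch defines the interfaces and the SetScore/Score accessors, but Schedule still picks randomly among the filtered pods, so no scorer is invoked yet.

```go
// Hypothetical scorer written against the new plugins.Scorer interface.
package scorers

import (
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// kvCacheScorer prefers pods with more free KV cache.
type kvCacheScorer struct{}

var _ plugins.Scorer = &kvCacheScorer{}

func (s *kvCacheScorer) Name() string { return "kv-cache-scorer" }

func (s *kvCacheScorer) Score(ctx *types.SchedulingContext, pod types.Pod) (float64, error) {
	// Higher score for lower KV cache utilization. A future scheduling loop could
	// persist this via pod.SetScore and pick the highest-scoring pod for the Result.
	return 1 - pod.GetMetrics().KVCacheUsagePercent, nil
}
```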