From 256acf8462e50b892dd730fbba4bf1da1f9e0a41 Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 01/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 121 ++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index abbff429a..70fecbad0 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -27,7 +27,10 @@ import ( "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" @@ -340,3 +343,121 @@ func TestMetrics(t *testing.T) { }) } } + +func TestPods(t *testing.T) { + poolSelector := map[string]string{"app": "vllm_v1"} + pool := testutil.MakeInferencePool("pool"). + Namespace("default"). + Selector(poolSelector).ObjRef() + updatedPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + }, + Spec: corev1.PodSpec{ + NodeName: "node-1", + }, + } + notReadyPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod2", + }, + Status: corev1.PodStatus{ + Conditions: []corev1.PodCondition{{Type: corev1.PodReady, Status: corev1.ConditionFalse}}, + }, + } + tests := []struct { + name string + op func(ctx context.Context, ds Datastore) + existingPods []*corev1.Pod + wantPods []*corev1.Pod + }{ + { + name: "Add new pod, no existing pods, should add", + existingPods: []*corev1.Pod{}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod1, pool) + }, + }, + { + name: "Add new pod, with existing pods, should add", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1, pod2}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod2, pool) + }, + }, + { + name: "Update existing pod, new field, should update", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{updatedPod}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(updatedPod, pool) + }, + }, + { + name: "Update existing pod, no new fields, should not update", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + incoming := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: "default", + }, + } + ds.PodUpdateOrAddIfNotExist(incoming, pool) + }, + }, + { + name: "Add not ready pod, resync required, should update", + existingPods: []*corev1.Pod{pod1, notReadyPod}, + wantPods: []*corev1.Pod{pod1, pod2}, + op: func(ctx context.Context, ds Datastore) { + scheme := runtime.NewScheme() + cfg := config.GetConfigOrDie() + cli, err := client.New(cfg, client.Options{Scheme: scheme}) + if err != nil { + t.Fatalf("Unable to create ctrl runtime client") + } + ds.PodResyncAll(ctx, cli, pool) + }, + }, + { + name: "Delete the pod", + existingPods: []*corev1.Pod{pod1, pod2}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodDelete(pod2NamespacedName) + }, + }, + { + name: "Delete the pod that doesn't exist", 
+ existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodDelete(pod2NamespacedName) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := NewDatastore(t.Context(), pmf) + for _, pod := range test.existingPods { + ds.PodUpdateOrAddIfNotExist(pod, pool) + } + + test.op(ctx, ds) + var gotPods []*corev1.Pod + for _, pm := range ds.PodGetAll() { + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}} + gotPods = append(gotPods, pod) + } + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { + t.Logf("got (%v) != want (%v);", gotPods, test.wantPods) + } + }) + } +} From 0b7f5376f1ff51e7153e8971985627627acc15ed Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 02/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 70fecbad0..4c4d4ee1b 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -357,14 +357,10 @@ func TestPods(t *testing.T) { NodeName: "node-1", }, } - notReadyPod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pod2", - }, - Status: corev1.PodStatus{ - Conditions: []corev1.PodCondition{{Type: corev1.PodReady, Status: corev1.ConditionFalse}}, - }, - } + resyncPoolSelector := map[string]string{"app": "llama3_8b"} + resyncPool := testutil.MakeInferencePool("pool"). + Namespace("default"). 
+ Selector(resyncPoolSelector).ObjRef() tests := []struct { name string op func(ctx context.Context, ds Datastore) @@ -410,8 +406,8 @@ func TestPods(t *testing.T) { }, }, { - name: "Add not ready pod, resync required, should update", - existingPods: []*corev1.Pod{pod1, notReadyPod}, + name: "Change pool selector, resync required, should update", + existingPods: []*corev1.Pod{pod1, pod2}, wantPods: []*corev1.Pod{pod1, pod2}, op: func(ctx context.Context, ds Datastore) { scheme := runtime.NewScheme() @@ -420,12 +416,8 @@ func TestPods(t *testing.T) { if err != nil { t.Fatalf("Unable to create ctrl runtime client") } - ds.PodResyncAll(ctx, cli, pool) - }, - }, - { + ds.PodResyncAll(ctx, cli, resyncPool) name: "Delete the pod", - existingPods: []*corev1.Pod{pod1, pod2}, wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { ds.PodDelete(pod2NamespacedName) From 2ba431d535a64254ee34d2375ee2a61f87ef1760 Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 03/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 4c4d4ee1b..a3b60d3e5 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -417,8 +417,10 @@ func TestPods(t *testing.T) { t.Fatalf("Unable to create ctrl runtime client") } ds.PodResyncAll(ctx, cli, resyncPool) - name: "Delete the pod", - wantPods: []*corev1.Pod{pod1}, + }, + }, { + name: "Delete the pod", + wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { ds.PodDelete(pod2NamespacedName) }, From 5067fac6146b8e026aad0a97eea47aaaf969cc69 Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 04/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index a3b60d3e5..9e5d5821b 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -418,7 +418,8 @@ func TestPods(t *testing.T) { } ds.PodResyncAll(ctx, cli, resyncPool) }, - }, { + }, + { name: "Delete the pod", wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { From 9b7167726c1207b7cb27c686dbbff78563d7ea66 Mon Sep 17 00:00:00 2001 From: Kellen Swain Date: Mon, 21 Apr 2025 12:01:01 -0700 Subject: [PATCH 05/20] EPP Architecture proposal (#683) * initial changes * Adding to proposal to give a quick barebones definition to refactor * feedback changes * more feedback addressing --- .../00x-epp-compliance-proposal/README.md | 99 +++++++++++++++++++ .../images/epp_arch.svg | 1 + 2 files changed, 100 insertions(+) create mode 100644 docs/proposals/00x-epp-compliance-proposal/README.md create mode 100644 docs/proposals/00x-epp-compliance-proposal/images/epp_arch.svg diff --git a/docs/proposals/00x-epp-compliance-proposal/README.md b/docs/proposals/00x-epp-compliance-proposal/README.md new file mode 100644 index 000000000..48c7720fb --- /dev/null +++ b/docs/proposals/00x-epp-compliance-proposal/README.md @@ -0,0 +1,99 @@ +# Gateway API Inference Extension + +Author(s): @kfswain +## Proposal Status + ***Draft*** + +## Table of Contents + + + +- [Summary](#summary) +- [Goals](#goals) +- [Non-Goals](#non-goals) +- 
[Proposal](#proposal) + - [Personas](#personas) + - [Inference Platform Admin](#inference-platform-admin) + - [Inference Workload Owner](#workload-owner) + - [Axioms](#axioms) + - [InferencePool](#inferencepool) + - [InferenceModel](#inferencemodel) + - [Spec](#spec) + - [Diagrams](#diagrams) + - [Alternatives](#alternatives) + - [Open Questions](#open-questions) + + + +## Summary + +This proposal seeks to standardize the implementation of an EPP (End-point Picker) for the Inference Gateway extension (also known as Gateway API Inference Extension). Additionally, it proposes to restructure the current implementation of the EPP to be more modular and approachable. + +## Goals + +- Set a standard on how the EPP & APIs interact +- Settle on common nomenclature for clearer communication +- Allow for modularization of the EPP, to be extended to a user's specific needs + +## Non-Goals + +- Reshaping the current API +- A change in scope of the current project + +## Proposal + +This proposal is not proposing any net new features; instead, we are refactoring our current implementation to better handle more devs, more features, etc. At the time of writing, GIE is currently at v0.3, and that stronger experimental context (along with external feedback) made clear the need for this restructure. The image below gives a high-level view of how our components work together. + +![Scheduling Algorithm](./images/epp_arch.svg) + +## Overview +At a quick glance, the EPP is being broken into specific layers. The `Data Layer` is of note, as it is a vertical that will be accessed by all the others. The data layer manages the k8s data, metric & usage data, as well as processing of the above data to determine resource scarcity regimes. + +The other layers are handled in a sequential process, starting with the **Ext-Proc** call. The request is buffered and then sent to the **Routing Layer**, which first processes any user-defined per-InferenceModel routing rules & request enrichment (at the time of writing that is currently just translating the InferenceModel name to a weight-split actual model). Then _all_ requests pass through the to-be-implemented [**Flow Controller**](https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/674) to ensure that any request entering the pool adheres to the guidelines set by the Priority, Fairness, & Queueing configuration. And finally, the **Scheduling Layer** is the load balancing algorithm that intelligently routes requests based on the current state of the InferencePool. + +## Components + +To further expand upon these component layers, we will first break them into `extensible` and `non-extensible` layers. `Non-extensible` layers are intended to be static, and handled on behalf of the user, typically implementing low-opinion infrastructure. + +The `Extensible` layers are: +- Data Layer +- Routing Layer +- Flow Controller +- Scheduling Layer + +The `Non-Extensible` layer(s) are: +- The Ext-Proc Server + +### `Extensible` + +#### Data Layer + +The data layer will consume and store: the InferencePool/InferenceModel config and the pre-defined [Model Server Protocol](../003-model-server-protocol/README.md). Additionally, the data fed from the model servers will be processed and digested to provide resource scarcity regime hints, and autoscaling recommendations.
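For illustration only, the sketch below shows the kind of per-pod snapshot and "scarcity hint" this layer could produce. The field names and thresholds are borrowed from the scheduler code later in this patch series (`backendmetrics.Metrics`, `KV_CACHE_THRESHOLD`, `QUEUE_THRESHOLD_CRITICAL`); the `podSnapshot` and `saturated` names, and the hint logic itself, are hypothetical placeholders rather than part of the proposal.

```go
package main

import "fmt"

// podSnapshot is a sketch of the per-pod data the data layer digests. The
// field names follow the backendmetrics.Metrics fields used by the scheduler
// later in this series; the real struct also tracks LoRA adapter state.
type podSnapshot struct {
	WaitingQueueSize    int     // requests queued on the model server
	KVCacheUsagePercent float64 // fraction of KV cache in use
}

// saturated is a toy scarcity hint. The thresholds mirror the scheduler
// defaults introduced later in this series (KV_CACHE_THRESHOLD=0.8,
// QUEUE_THRESHOLD_CRITICAL=5); the actual regime logic is still to be defined.
func saturated(p podSnapshot) bool {
	return p.KVCacheUsagePercent > 0.8 || p.WaitingQueueSize > 5
}

func main() {
	p := podSnapshot{WaitingQueueSize: 7, KVCacheUsagePercent: 0.55}
	fmt.Printf("scarce capacity: %v\n", saturated(p))
}
```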
+ Many extensions to scheduling will require changes to the ingested metrics; as such, the data layer will be built to be extended, but extenders accept that the Model Server Protocol will no longer provide guarantees on portability of a model server out of the box. + +#### Routing Layer + +The routing layer is likely to be the most opinion-heavy section, as the scope of what constitutes a 'Route Rule' is somewhat broad. The current examples we expect would be: + +- System Prompt injection +- RAG callout +- Per-InferenceModel request validation (such as safety/on-topic, etc) + +Due to the possibility of this becoming a bit of a dumping ground, the API will keep a _very_ tight scope on which of these route rules are included in the spec. A standard method of extension will be provided if the need to define a custom rule arises. + +#### Flow Controller (WIP - implementation tracked in [#674](https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/674)) + +The flow controller will consume resource regime data, and enforce proper resource sharing between workloads. This will primarily be done through a queuing mechanism [as described here](https://docs.google.com/document/d/1VZL7opFWuwgWquvgiOzLlXAJ633qZ9U-A0ZixGjBgaI/edit?usp=sharing). + +#### Scheduling Layer + +As the Scheduling Layer is the final interface to the entirety of the pool, all configuration will be at the _pool_ level. The default scheduling layer will be an experimentally-backed LB algorithm, with exposed config values. + +The Scheduler will define a strong interface API, so that new scheduling algos may be plugged in & dark-launched to test on production traffic without impacting said traffic. Extension is expected to adhere to the [Scheduler Subsystem definition](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/603); a rough sketch of the implied plugin surface follows at the end of this document. + +### `Non-extensible` + +#### Ext-Proc Server + +The Ext-Proc Server protocol is very well defined & specific; deviation could cause the EPP to become unusable or unstable. Extension is ill-advised.
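To make the "strong interface API" point above concrete: the scheduler refactor later in this series (patch 08, `pkg/epp/scheduling/types` and `pkg/epp/scheduling/plugins`) introduces filter and picker plugins. The sketch below is an abbreviated rendering of that surface, not the authoritative definition; the simplified `Context`, `Metrics` and `Result` shapes are assumptions made here for brevity.

```go
// Package scheduling: an abbreviated sketch of the extension surface implied
// by the scheduler refactor later in this series (see
// pkg/epp/scheduling/types/interfaces.go in patch 08). The Context, Pod and
// Result types are simplified; refer to the actual patch for the full
// definitions.
package scheduling

// Pod exposes the per-pod snapshot supplied by the Data Layer.
type Pod interface {
	GetMetrics() *Metrics
}

// Metrics is a trimmed-down view of the snapshot fields the filters read.
type Metrics struct {
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

// Context carries the request plus the pod snapshot for one scheduling pass.
type Context struct{}

// Result names the pod the scheduler picked.
type Result struct {
	TargetPod Pod
}

// Filter narrows the candidate pods; filters can be chained (see the
// DecisionTreeFilter in the refactor) to build a flow-chart style algorithm.
type Filter interface {
	Name() string
	Filter(ctx *Context, pods []Pod) ([]Pod, error)
}

// Picker selects the final target from whatever the filters left over.
type Picker interface {
	Name() string
	Pick(ctx *Context, pods []Pod) (*Result, error)
}
```

In that refactor, the default algorithm is expressed as a `DecisionTreeFilter` chain (low-queue, LoRA affinity, least-queue, least-KV-cache) followed by a random picker, which is the kind of composition a custom plugin would slot into.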
diff --git a/docs/proposals/00x-epp-compliance-proposal/images/epp_arch.svg b/docs/proposals/00x-epp-compliance-proposal/images/epp_arch.svg new file mode 100644 index 000000000..4c5857281 --- /dev/null +++ b/docs/proposals/00x-epp-compliance-proposal/images/epp_arch.svg @@ -0,0 +1 @@ + \ No newline at end of file From bfbc35f395c5e3eb6ba536e0f443f6bc2fbbd7cc Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Tue, 22 Apr 2025 18:49:41 +0300 Subject: [PATCH 06/20] removed unused Fake struct (#723) Signed-off-by: Nir Rozenbaum --- pkg/epp/backend/metrics/fake.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go index 7fd4970db..ec97c6dea 100644 --- a/pkg/epp/backend/metrics/fake.go +++ b/pkg/epp/backend/metrics/fake.go @@ -24,7 +24,6 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -84,11 +83,3 @@ func (f *FakePodMetricsClient) SetErr(new map[types.NamespacedName]error) { defer f.errMu.Unlock() f.Err = new } - -type FakeDataStore struct { - Res map[string]*v1alpha2.InferenceModel -} - -func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha2.InferenceModel) { - return fds.Res[modelName] -} From fc980b88acdb9a118eb32f145fef48168f17a2c7 Mon Sep 17 00:00:00 2001 From: John Howard Date: Tue, 22 Apr 2025 14:59:40 -0700 Subject: [PATCH 07/20] epp: return correct response for trailers (#726) This looks like a copy paste error. --- pkg/epp/handlers/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 7bb0fcb16..f97e9ede0 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -325,7 +325,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces } if r.RequestState == BodyRequestResponsesComplete && r.reqTrailerResp != nil { // Trailers in requests are not guaranteed - if err := srv.Send(r.reqHeaderResp); err != nil { + if err := srv.Send(r.reqTrailerResp); err != nil { return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) } } @@ -351,7 +351,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces } if r.RequestState == BodyResponseResponsesComplete && r.respTrailerResp != nil { // Trailers in requests are not guaranteed - if err := srv.Send(r.reqHeaderResp); err != nil { + if err := srv.Send(r.respTrailerResp); err != nil { return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) } } From 55600b4e44ff1a8007c44d95e2c4e428d616f949 Mon Sep 17 00:00:00 2001 From: Cong Liu Date: Tue, 22 Apr 2025 15:15:41 -0700 Subject: [PATCH 08/20] Refactor scheduler to run plugins (#677) * Refactor scheduler to run plugins * Add scheduler plugin latency metric * Address comments * Address comments --- pkg/epp/backend/metrics/types.go | 6 + pkg/epp/handlers/request.go | 9 +- pkg/epp/handlers/server.go | 2 +- pkg/epp/metrics/metrics.go | 22 ++ pkg/epp/metrics/metrics_test.go | 64 ++++ ...heduler_plugin_processing_latencies_metric | 67 ++++ pkg/epp/scheduling/config/config.go | 58 +++ pkg/epp/scheduling/{ => plugins}/filter.go | 144 ++++---- .../scheduling/{ => plugins}/filter_test.go | 91 ++--- pkg/epp/scheduling/plugins/noop.go | 38 ++ pkg/epp/scheduling/plugins/picker.go | 37 ++ 
pkg/epp/scheduling/scheduler.go | 236 ++++++++----- pkg/epp/scheduling/scheduler_test.go | 331 ++++++++++++++++-- pkg/epp/scheduling/types/interfaces.go | 75 ++++ pkg/epp/scheduling/types/types.go | 35 +- 15 files changed, 969 insertions(+), 246 deletions(-) create mode 100644 pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric create mode 100644 pkg/epp/scheduling/config/config.go rename pkg/epp/scheduling/{ => plugins}/filter.go (60%) rename pkg/epp/scheduling/{ => plugins}/filter_test.go (82%) create mode 100644 pkg/epp/scheduling/plugins/noop.go create mode 100644 pkg/epp/scheduling/plugins/picker.go create mode 100644 pkg/epp/scheduling/types/interfaces.go diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 925a0cc5a..21c0f4016 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -79,6 +79,9 @@ func (p *Pod) String() string { } func (p *Pod) Clone() *Pod { + if p == nil { + return nil + } return &Pod{ NamespacedName: types.NamespacedName{ Name: p.NamespacedName.Name, @@ -118,6 +121,9 @@ func (m *Metrics) String() string { } func (m *Metrics) Clone() *Metrics { + if m == nil { + return nil + } cm := make(map[string]int, len(m.ActiveModels)) for k, v := range m.ActiveModels { cm[k] = v diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 44537923d..9121b59af 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -67,7 +67,7 @@ func (s *StreamingServer) HandleRequestBody( ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, } - logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical) + logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) var err error // Update target models in the body. @@ -81,11 +81,11 @@ func (s *StreamingServer) HandleRequestBody( return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} } - target, err := s.scheduler.Schedule(ctx, llmReq) + res, err := s.scheduler.Schedule(ctx, llmReq) if err != nil { return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} } - targetPod := target.GetPod() + targetPod := res.TargetPod.GetPod() // Insert target endpoint to instruct Envoy to route requests to the specified target pod. 
// Attach the port number @@ -96,8 +96,7 @@ func (s *StreamingServer) HandleRequestBody( endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics", - fmt.Sprintf("%+v", target)) + "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index f97e9ede0..2e3a35fe7 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -65,7 +65,7 @@ type StreamingServer struct { } type Scheduler interface { - Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error) + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) } // RequestContext stores context information during the life time of an HTTP request. diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index b474df365..56dcfca8c 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -30,6 +30,7 @@ import ( const ( InferenceModelComponent = "inference_model" InferencePoolComponent = "inference_pool" + EPPComponent = "endpoint_picker" ) var ( @@ -176,6 +177,20 @@ var ( }, []string{"name"}, ) + + // Scheduler Plugin Metrics + SchedulerPluginProcessingLatencies = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: EPPComponent, + Name: "scheduler_plugin_duration_seconds", + Help: "Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name.", + Buckets: []float64{ + 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, + }, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"plugin_type", "plugin_name"}, + ) ) var registerMetrics sync.Once @@ -196,6 +211,8 @@ func Register() { legacyregistry.MustRegister(inferencePoolAvgKVCache) legacyregistry.MustRegister(inferencePoolAvgQueueSize) legacyregistry.MustRegister(inferencePoolReadyPods) + + legacyregistry.MustRegister(SchedulerPluginProcessingLatencies) }) } @@ -293,3 +310,8 @@ func RecordInferencePoolAvgQueueSize(name string, queueSize float64) { func RecordinferencePoolReadyPods(name string, runningPods float64) { inferencePoolReadyPods.WithLabelValues(name).Set(runningPods) } + +// RecordSchedulerPluginProcessingLatency records the processing latency for a scheduler plugin. 
+func RecordSchedulerPluginProcessingLatency(pluginType, pluginName string, duration time.Duration) { + SchedulerPluginProcessingLatencies.WithLabelValues(pluginType, pluginName).Observe(duration.Seconds()) +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index b5f19e6d0..81797e6de 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -556,3 +556,67 @@ func TestInferencePoolMetrics(t *testing.T) { }) } } + +func TestSchedulerPluginProcessingLatencies(t *testing.T) { + type pluginLatency struct { + pluginType string + pluginName string + duration time.Duration + } + scenarios := []struct { + name string + latencies []pluginLatency + }{ + { + name: "multiple plugins", + latencies: []pluginLatency{ + { + pluginType: "PreSchedule", + pluginName: "PluginA", + duration: 100 * time.Millisecond, + }, + { + pluginType: "PostSchedule", + pluginName: "PluginB", + duration: 200 * time.Millisecond, + }, + { + pluginType: "Filter", + pluginName: "PluginC", + duration: 50 * time.Millisecond, + }, + { + pluginType: "Scorer", + pluginName: "PluginD", + duration: 10 * time.Millisecond, + }, + { + pluginType: "Picker", + pluginName: "PluginE", + duration: 10 * time.Microsecond, + }, + }, + }, + } + Register() + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + for _, latency := range scenario.latencies { + RecordSchedulerPluginProcessingLatency(latency.pluginType, latency.pluginName, latency.duration) + } + + wantPluginLatencies, err := os.Open("testdata/scheduler_plugin_processing_latencies_metric") + defer func() { + if err := wantPluginLatencies.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantPluginLatencies, "endpoint_picker_scheduler_plugin_processing_latencies"); err != nil { + t.Error(err) + } + }) + } +} diff --git a/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric new file mode 100644 index 000000000..8c11757f4 --- /dev/null +++ b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric @@ -0,0 +1,67 @@ +# HELP endpoint_picker_scheduler_plugin_duration_seconds [ALPHA] Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name. 
+# TYPE endpoint_picker_scheduler_plugin_duration_seconds histogram +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginA",plugin_type="PreSchedule"} 0.1 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginA",plugin_type="PreSchedule"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.1"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginB",plugin_type="PostSchedule"} 0.2 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginB",plugin_type="PostSchedule"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0005"} 0 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginC",plugin_type="Filter"} 0.05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginC",plugin_type="Filter"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.02"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginD",plugin_type="Scorer"} 0.01 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginD",plugin_type="Scorer"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0002"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.002"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.02"} 1 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginE",plugin_type="Picker"} 1e-05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginE",plugin_type="Picker"} 1 diff --git a/pkg/epp/scheduling/config/config.go b/pkg/epp/scheduling/config/config.go new file mode 100644 index 000000000..e00b82aec --- /dev/null +++ b/pkg/epp/scheduling/config/config.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package config + +import ( + "sigs.k8s.io/controller-runtime/pkg/log" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// Config holds all the configuration values for the scheduler +type Config struct { + KVCacheThreshold float64 + QueueThresholdCritical int + QueueingThresholdLoRA int + LoraAffinityThreshold float64 +} + +const ( + // Default values to use if environment variables are not set + defaultKVCacheThreshold = 0.8 + defaultQueueThresholdCritical = 5 + defaultQueueingThresholdLoRA = 128 + defaultLoraAffinityThreshold = 0.999 +) + +// LoadConfig loads configuration from environment variables +func LoadConfig() Config { + // Use a default logger for initial configuration loading + baseLogger := log.Log.WithName("scheduling-config") + + config := Config{ + KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), + QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), + QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), + LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), + } + + baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) + + return config +} + +var Conf = LoadConfig() diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/plugins/filter.go similarity index 60% rename from pkg/epp/scheduling/filter.go rename to pkg/epp/scheduling/plugins/filter.go index 99044e976..efcb6be17 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/plugins/filter.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package scheduling +package plugins import ( "errors" @@ -22,83 +22,80 @@ import ( "math/rand" "time" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -type Filter interface { - Name() string - Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) -} - -type basicFilter struct { +type Filter struct { name string filter filterFunc } -func (bf *basicFilter) Name() string { +func (bf *Filter) Name() string { if bf == nil { return "nil" } return bf.name } -func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (bf *Filter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { loggerTrace := ctx.Logger.V(logutil.TRACE) loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods)) return bf.filter(ctx, pods) } -// decisionTreeFilter applies current filterFunc, and then recursively applies next filters +// DecisionTreeFilter applies current filterFunc, and then recursively applies next filters // depending success or failure of the current filter. // It can be used to construct a flow chart algorithm. -type decisionTreeFilter struct { - current Filter - // nextOnSuccess filter will be applied after successfully applying the current filter. +type DecisionTreeFilter struct { + Current types.Filter + // NextOnSuccess filter will be applied after successfully applying the current filter. // The filtered results will be passed to the next filter. - nextOnSuccess Filter - // nextOnFailure filter will be applied if current filter fails. + NextOnSuccess types.Filter + // NextOnFailure filter will be applied if current filter fails. // The original input will be passed to the next filter. - nextOnFailure Filter - // nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the + NextOnFailure types.Filter + // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the // success or failure of the current filter. - // NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. + // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of - // nextOnSuccessOrFailure, in the success and failure scenarios, respectively. - nextOnSuccessOrFailure Filter + // NextOnSuccessOrFailure, in the success and failure scenarios, respectively. 
+ NextOnSuccessOrFailure types.Filter } -func (f *decisionTreeFilter) Name() string { +func (f *DecisionTreeFilter) Name() string { if f == nil { return "nil" } - return f.current.Name() + return f.Current.Name() } -func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { loggerTrace := ctx.Logger.V(logutil.TRACE) - filtered, err := f.current.Filter(ctx, pods) + filtered, err := f.Current.Filter(ctx, pods) - next := f.nextOnSuccessOrFailure + next := f.NextOnSuccessOrFailure if err == nil && len(filtered) > 0 { - if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil { + if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. return filtered, err } - if f.nextOnSuccess != nil { - next = f.nextOnSuccess + if f.NextOnSuccess != nil { + next = f.NextOnSuccess } loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered)) // On success, pass the filtered result to the next filter. return next.Filter(ctx, filtered) } else { - if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil { + if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. return filtered, err } - if f.nextOnFailure != nil { - next = f.nextOnFailure + if f.NextOnFailure != nil { + next = f.NextOnFailure } loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) // On failure, pass the initial set of pods to the next filter. @@ -107,12 +104,12 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics } // filterFunc filters a set of input pods to a subset. -type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) +type filterFunc func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - filtered := []*types.PodMetrics{} + return func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + filtered := []types.Pod{} for _, pod := range pods { pass := pp(ctx.Req, pod) if pass { @@ -126,7 +123,7 @@ func toFilterFunc(pp podPredicate) filterFunc { } } -var leastQueueFilter = &basicFilter{ +var LeastQueueFilter = &Filter{ name: "least queuing", filter: leastQueuingFilterFunc, } @@ -138,34 +135,34 @@ var leastQueueFilter = &basicFilter{ // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
-func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { min := math.MaxInt max := 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.WaitingQueueSize <= min { - min = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize <= min { + min = pod.GetMetrics().WaitingQueueSize } - if pod.WaitingQueueSize >= max { - max = pod.WaitingQueueSize + if pod.GetMetrics().WaitingQueueSize >= max { + max = pod.GetMetrics().WaitingQueueSize } } for _, pod := range pods { - if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) { + if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) { filtered = append(filtered, pod) } } return filtered, nil } -var lowQueueFilter = &basicFilter{ +var LowQueueFilter = &Filter{ name: "low queueing filter", - filter: toFilterFunc((queueThresholdPredicate(config.QueueingThresholdLoRA))), + filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), } -var leastKVCacheFilter = &basicFilter{ +var LeastKVCacheFilter = &Filter{ name: "least KV cache percent", filter: leastKVCacheFilterFunc, } @@ -176,29 +173,29 @@ var leastKVCacheFilter = &basicFilter{ // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { min := math.MaxFloat64 var max float64 = 0 - filtered := []*types.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { - if pod.KVCacheUsagePercent <= min { - min = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent <= min { + min = pod.GetMetrics().KVCacheUsagePercent } - if pod.KVCacheUsagePercent >= max { - max = pod.KVCacheUsagePercent + if pod.GetMetrics().KVCacheUsagePercent >= max { + max = pod.GetMetrics().KVCacheUsagePercent } } for _, pod := range pods { - if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { + if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) { filtered = append(filtered, pod) } } return filtered, nil } -var loRAAffinityFilter = &basicFilter{ +var LoRAAffinityFilter = &Filter{ name: "affinity LoRA", filter: loRASoftAffinityFilterFunc, } @@ -219,20 +216,20 @@ var loRAAffinityFilter = &basicFilter{ // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { +func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]*types.PodMetrics, 0, len(pods)) - filtered_available := make([]*types.PodMetrics, 0, len(pods)) + filtered_affinity := make([]types.Pod, 0, len(pods)) + filtered_available := make([]types.Pod, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { - _, active := 
pod.ActiveModels[ctx.Req.ResolvedTargetModel] - _, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel] + _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] if active || waiting { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels { + } else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -243,7 +240,7 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([ // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { - if randGen.Float64() < config.LoraAffinityThreshold { + if randGen.Float64() < config.Conf.LoraAffinityThreshold { return filtered_affinity, nil } return filtered_available, nil @@ -257,23 +254,38 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([ return filtered_available, nil } +var HasCapacityFilter = &Filter{ + name: "has capacity for sheddable requests", + filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), +} + +var DropRequestFilter = &Filter{ + name: "drop request", + filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) + return []types.Pod{}, errutil.Error{ + Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", + } + }, +} + // podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool +type podPredicate func(req *types.LLMRequest, pod types.Pod) bool func queueThresholdPredicate(queueThreshold int) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.WaitingQueueSize <= queueThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold } } func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { - return pod.KVCacheUsagePercent <= kvCacheThreshold + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold } } func (pp podPredicate) and(another podPredicate) podPredicate { - return func(req *types.LLMRequest, pod *types.PodMetrics) bool { + return func(req *types.LLMRequest, pod types.Pod) bool { return pp(req, pod) && another(req, pod) } } diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/plugins/filter_test.go similarity index 82% rename from pkg/epp/scheduling/filter_test.go rename to pkg/epp/scheduling/plugins/filter_test.go index 543826d06..107b423fb 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter_test.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package scheduling +package plugins import ( "context" @@ -24,6 +24,7 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -31,17 +32,17 @@ func TestFilter(t *testing.T) { tests := []struct { name string req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics + input []types.Pod + output []types.Pod err bool - filter *decisionTreeFilter + filter *DecisionTreeFilter }{ { name: "simple filter without successor, failure", - filter: &decisionTreeFilter{ - current: &basicFilter{ + filter: &DecisionTreeFilter{ + Current: &Filter{ name: "error", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { + filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { return nil, errors.New("filter error") }, }, @@ -58,7 +59,8 @@ func TestFilter(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -70,43 +72,43 @@ func TestFilterFunc(t *testing.T) { name string f filterFunc req *types.LLMRequest - input []*types.PodMetrics - output []*types.PodMetrics + input []types.Pod + output []types.Pod err bool }{ { name: "least queuing empty input", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least queuing", f: leastQueuingFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 10, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 3, }, @@ -116,36 +118,36 @@ func TestFilterFunc(t *testing.T) { { name: "least kv cache empty input", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{}, - output: []*types.PodMetrics{}, + input: []types.Pod{}, + output: []types.Pod{}, }, { name: "least kv cache", f: leastKVCacheFilterFunc, - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 1.0, }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ KVCacheUsagePercent: 0.3, }, @@ -155,22 +157,22 @@ func TestFilterFunc(t *testing.T) { { name: "lowQueueAndLessThanKVCacheThresholdPredicate", f: toFilterFunc(queueThresholdPredicate(0).and(kvCacheThresholdPredicate(0.8))), - input: []*types.PodMetrics{ - { + input: []types.Pod{ + &types.PodMetrics{ // This pod should be returned. 
Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, }, }, - { + &types.PodMetrics{ // Queue is non zero, despite low kv cache, should not return. Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.3, }, }, - { + &types.PodMetrics{ // High kv cache despite zero queue, should not return Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, @@ -178,8 +180,8 @@ func TestFilterFunc(t *testing.T) { }, }, }, - output: []*types.PodMetrics{ - { + output: []types.Pod{ + &types.PodMetrics{ Metrics: &backendmetrics.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -197,7 +199,8 @@ func TestFilterFunc(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.output, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -215,15 +218,15 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { ) // Save original config value to restore later - originalThreshold := config.LoraAffinityThreshold + originalThreshold := config.Conf.LoraAffinityThreshold // Set a specific test value for this test testThreshold := 0.75 // 75% - config.LoraAffinityThreshold = testThreshold + config.Conf.LoraAffinityThreshold = testThreshold // Ensure we restore the original threshold when test completes defer func() { - config.LoraAffinityThreshold = originalThreshold + config.Conf.LoraAffinityThreshold = originalThreshold }() // Create a test request and pods @@ -233,8 +236,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { } // Test setup: One affinity pod and one available pod - pods := []*types.PodMetrics{ - { + pods := []types.Pod{ + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "affinity-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -243,7 +246,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, }, - { + &types.PodMetrics{ Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "available-pod"}}, Metrics: &backendmetrics.Metrics{ MaxActiveModels: 2, @@ -258,7 +261,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { availableCount := 0 // Use the test threshold value - expectedAffinityPercent := config.LoraAffinityThreshold * 100 + expectedAffinityPercent := config.Conf.LoraAffinityThreshold * 100 expectedAvailabilityPercent := 100 - expectedAffinityPercent for i := 0; i < numIterations; i++ { @@ -292,8 +295,8 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { availableUpperBound := expectedAvailabilityPercent + tolerancePercent t.Logf("Distribution results over %d iterations:", numIterations) - t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.LoraAffinityThreshold) - t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.LoraAffinityThreshold) + t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.Conf.LoraAffinityThreshold) + t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.Conf.LoraAffinityThreshold) t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations) t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations) diff --git a/pkg/epp/scheduling/plugins/noop.go 
b/pkg/epp/scheduling/plugins/noop.go new file mode 100644 index 000000000..1abcb95b1 --- /dev/null +++ b/pkg/epp/scheduling/plugins/noop.go @@ -0,0 +1,38 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// NoopPlugin provides a default, no-operation implementation of the Plugin interface. +// It can be embedded in other plugin implementations to avoid boilerplate code for +// unused methods. +type NoopPlugin struct{} + +func (p *NoopPlugin) Name() string { return "NoopPlugin" } + +func (p *NoopPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { return 0.0, nil } + +func (p *NoopPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + return pods, nil +} + +func (p *NoopPlugin) PreSchedule(ctx *types.Context) {} + +func (p *NoopPlugin) PostSchedule(ctx *types.Context, res *types.Result) {} diff --git a/pkg/epp/scheduling/plugins/picker.go b/pkg/epp/scheduling/plugins/picker.go new file mode 100644 index 000000000..569e4e86a --- /dev/null +++ b/pkg/epp/scheduling/plugins/picker.go @@ -0,0 +1,37 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package plugins + +import ( + "fmt" + "math/rand" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +type RandomPicker struct{} + +func (rp *RandomPicker) Name() string { + return "random" +} + +func (rp *RandomPicker) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) + i := rand.Intn(len(pods)) + return &types.Result{TargetPod: pods[i]}, nil +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 8679ffbad..7cc2bd968 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -20,113 +20,71 @@ package scheduling import ( "context" "fmt" - "math/rand" + "time" "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// Config holds all the configuration values for the scheduler -type Config struct { - KVCacheThreshold float64 - QueueThresholdCritical int - QueueingThresholdLoRA int - LoraAffinityThreshold float64 -} - -const ( - // Default values to use if environment variables are not set - defaultKVCacheThreshold = 0.8 - defaultQueueThresholdCritical = 5 - defaultQueueingThresholdLoRA = 128 - defaultLoraAffinityThreshold = 0.999 -) - -// LoadConfig loads configuration from environment variables -func LoadConfig() Config { - // Use a default logger for initial configuration loading - baseLogger := log.Log.WithName("scheduling-config") - - config := Config{ - KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), - QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), - QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), - LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), - } - - baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) - - return config -} - -var config = LoadConfig() - var ( - lowLatencyFilter = &decisionTreeFilter{ - current: lowQueueFilter, - nextOnSuccess: &decisionTreeFilter{ - current: loRAAffinityFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: leastQueueFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: leastKVCacheFilter, + lowLatencyFilter = &plugins.DecisionTreeFilter{ + Current: plugins.LowQueueFilter, + NextOnSuccess: &plugins.DecisionTreeFilter{ + Current: plugins.LoRAAffinityFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastQueueFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastKVCacheFilter, }, }, }, - nextOnFailure: &decisionTreeFilter{ - current: leastQueueFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: loRAAffinityFilter, - nextOnSuccessOrFailure: &decisionTreeFilter{ - current: 
leastKVCacheFilter, + NextOnFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastQueueFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LoRAAffinityFilter, + NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ + Current: plugins.LeastKVCacheFilter, }, }, }, } - sheddableRequestFilter = &decisionTreeFilter{ + sheddableRequestFilter = &plugins.DecisionTreeFilter{ // When there is at least one model server that's not queuing requests, and still has KV // cache below a certain threshold, we consider this model server has capacity to handle // a sheddable request without impacting critical requests. - current: hasCapacityFilter, - nextOnSuccess: lowLatencyFilter, + Current: plugins.HasCapacityFilter, + NextOnSuccess: lowLatencyFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable // request to make room for critical requests. - nextOnFailure: dropRequestFilter, - } - - hasCapacityFilter = &basicFilter{ - name: "has capacity for sheddable requests", - filter: toFilterFunc(queueThresholdPredicate(config.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.KVCacheThreshold))), - } - - dropRequestFilter = &basicFilter{ - name: "drop request", - filter: func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) { - ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) - return []*types.PodMetrics{}, errutil.Error{ - Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", - } - }, + NextOnFailure: plugins.DropRequestFilter, } ) func NewScheduler(datastore Datastore) *Scheduler { + defaultPlugin := &defaultPlugin{} + return &Scheduler{ - datastore: datastore, - criticalRequestFilter: lowLatencyFilter, - sheddableRequestFilter: sheddableRequestFilter, + datastore: datastore, + preSchedulePlugins: []types.PreSchedule{}, + postSchedulePlugins: []types.PostSchedule{}, + scorers: []types.Scorer{}, + filters: []types.Filter{defaultPlugin}, + picker: defaultPlugin, } } type Scheduler struct { - datastore Datastore - criticalRequestFilter Filter - sheddableRequestFilter Filter + datastore Datastore + preSchedulePlugins []types.PreSchedule + postSchedulePlugins []types.PostSchedule + filters []types.Filter + scorers []types.Scorer + picker types.Picker } type Datastore interface { @@ -134,27 +92,125 @@ type Datastore interface { } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (targetPod types.Pod, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("request", req) + loggerDebug := logger.V(logutil.DEBUG) // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. sCtx := types.NewContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) - logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) + loggerDebug.Info(fmt.Sprintf("Scheduling a request. 
Metrics: %+v", sCtx.PodsSnapshot)) - var filter Filter - if req.Critical { - filter = s.criticalRequestFilter - } else { - filter = s.sheddableRequestFilter + s.runPreSchedulePlugins(sCtx) + + pods, err := s.runFilterPlugins(sCtx) + if err != nil { + return nil, err + } + + if err := s.runScorerPlugins(sCtx, pods); err != nil { + return nil, err + } + + before := time.Now() + res, err := s.picker.Pick(sCtx, pods) + metrics.RecordSchedulerPluginProcessingLatency(types.PickerPluginType, s.picker.Name(), time.Since(before)) + if err != nil { + return nil, err } + loggerDebug.Info("After running picker plugins", "result", res) - pods, err := filter.Filter(sCtx, sCtx.PodsSnapshot) - if err != nil || len(pods) == 0 { - return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) + s.runPostSchedulePlugins(sCtx, res) + + return res, nil +} + +func (s *Scheduler) runPreSchedulePlugins(ctx *types.Context) { + for _, plugin := range s.preSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running pre-schedule plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PreSchedule(ctx) + metrics.RecordSchedulerPluginProcessingLatency(types.PreSchedulerPluginType, plugin.Name(), time.Since(before)) + } +} + +func (s *Scheduler) runPostSchedulePlugins(ctx *types.Context, res *types.Result) { + for _, plugin := range s.postSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostSchedule(ctx, res) + metrics.RecordSchedulerPluginProcessingLatency(types.PostSchedulePluginType, plugin.Name(), time.Since(before)) + } +} + +func (s *Scheduler) runFilterPlugins(ctx *types.Context) ([]types.Pod, error) { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + pods := ctx.PodsSnapshot + loggerDebug.Info("Before running filter plugins", "pods", pods) + for _, filter := range s.filters { + loggerDebug.Info("Running filter plugin", "plugin", filter.Name()) + before := time.Now() + filteredPods, err := filter.Filter(ctx, pods) + metrics.RecordSchedulerPluginProcessingLatency(types.FilterPluginType, filter.Name(), time.Since(before)) + if err != nil || len(filteredPods) == 0 { + return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(filteredPods), err) + } + pods = filteredPods + loggerDebug.Info("Filter plugin result", "plugin", filter.Name(), "pods", pods) + } + loggerDebug.Info("After running filter plugins", "pods", pods) + return pods, nil +} + +func (s *Scheduler) runScorerPlugins(ctx *types.Context, pods []types.Pod) error { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + loggerDebug.Info("Before running score plugins", "pods", pods) + for _, pod := range pods { + score, err := runScorersForPod(ctx, s.scorers, pod) + if err != nil { + return err + } + pod.SetScore(score) + } + loggerDebug.Info("After running score plugins", "pods", pods) + return nil +} + +// Iterate through each scorer in the chain and accumulate the scores. 
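+// The per-pod score is the unweighted sum of every scorer's output; if any scorer returns an error, scoring is aborted and the request fails.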
+func runScorersForPod(ctx *types.Context, scorers []types.Scorer, pod types.Pod) (float64, error) { + logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG) + score := float64(0) + for _, scorer := range scorers { + logger.Info("Running scorer", "scorer", scorer.Name()) + before := time.Now() + oneScore, err := scorer.Score(ctx, pod) + metrics.RecordSchedulerPluginProcessingLatency(types.ScorerPluginType, scorer.Name(), time.Since(before)) + if err != nil { + logger.Error(err, "Failed to calculate score for scorer", "scorer", scorer.Name()) + return 0, err + } + score += oneScore + logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score) + } + return score, nil +} + +type defaultPlugin struct { + plugins.RandomPicker +} + +func (p *defaultPlugin) Name() string { + return "DefaultPlugin" +} + +func (p *defaultPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + req := ctx.Req + var filter types.Filter + if req.Critical { + filter = lowLatencyFilter + } else { + filter = sheddableRequestFilter } - logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) - i := rand.Intn(len(pods)) - return pods[i], nil + return filter.Filter(ctx, pods) } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 3fd3fb244..5a2265bff 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -18,22 +18,34 @@ package scheduling import ( "context" + "errors" "testing" "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +// Tests the default scheduler configuration and expected behavior. 
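+// The default configuration wired up by NewScheduler sends critical requests through the low-latency decision tree and sheddable requests through the capacity check, then picks a pod at random.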
func TestSchedule(t *testing.T) { tests := []struct { - name string - req *types.LLMRequest - input []*backendmetrics.FakePodMetrics - output types.Pod - err bool + name string + req *types.LLMRequest + input []*backendmetrics.FakePodMetrics + wantRes *types.Result + err bool }{ + { + name: "no pods in datastore", + req: &types.LLMRequest{ + Model: "any-model", + ResolvedTargetModel: "any-model", + Critical: true, + }, + input: []*backendmetrics.FakePodMetrics{}, + err: true, + }, { name: "critical request", req: &types.LLMRequest{ @@ -80,17 +92,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, + wantRes: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.1, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -139,17 +153,19 @@ func TestSchedule(t *testing.T) { }, }, }, - output: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + wantRes: &types.Result{ + TargetPod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, }, }, @@ -199,8 +215,8 @@ func TestSchedule(t *testing.T) { }, }, }, - output: nil, - err: true, + wantRes: nil, + err: true, }, } @@ -212,13 +228,205 @@ func TestSchedule(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + opt := cmp.AllowUnexported(types.PodMetrics{}) + if diff := cmp.Diff(test.wantRes, got, opt); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) } } +func TestSchedulePlugins(t *testing.T) { + tp1 := &TestPlugin{ + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + } + tp2 := &TestPlugin{ + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + } + tpFilterErr := &TestPlugin{ + NameRes: "filter err", + FilterErr: errors.New("filter error"), + } + tpScorerErr := &TestPlugin{ + NameRes: "score err", + ScoreErr: errors.New("score err"), + } + pickerPlugin := &TestPlugin{ + NameRes: "picker", + PickRes: k8stypes.NamespacedName{Name: "pod1"}, + } + pickerErr := &TestPlugin{ + NameRes: "picker err", + PickErr: errors.New("picker err"), + } + + tests := []struct { + name string + preSchedulePlugins []types.PreSchedule + postSchedulePlugins []types.PostSchedule + filters []types.Filter + scorers []types.Scorer + picker types.Picker + input []*backendmetrics.FakePodMetrics + wantTargetPod k8stypes.NamespacedName + targetPodScore float64 + // Number of expected pods to score (after 
filter) + numPodsToScore int + err bool + }{ + { + name: "all plugins executed successfully", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, + }, + { + name: "filter error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tpFilterErr}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + { + name: "scorer error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tpScorerErr}, + picker: pickerPlugin, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + { + name: "picker error", + preSchedulePlugins: []types.PreSchedule{tp1, tp2}, + postSchedulePlugins: []types.PostSchedule{tp1, tp2}, + filters: []types.Filter{tp1, tp2}, + scorers: []types.Scorer{tp1, tp2}, + picker: pickerErr, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + err: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Reset all plugins before each new test case. 
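+			// The TestPlugin fixtures (tp1, tp2, etc.) are shared across test cases, so their call counters must be cleared before each run.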
+ for _, plugin := range test.preSchedulePlugins { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.postSchedulePlugins { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.filters { + plugin.(*TestPlugin).Reset() + } + for _, plugin := range test.scorers { + plugin.(*TestPlugin).Reset() + } + test.picker.(*TestPlugin).Reset() + + // Initialize the scheduler + scheduler := &Scheduler{ + datastore: &fakeDataStore{pods: test.input}, + preSchedulePlugins: test.preSchedulePlugins, + postSchedulePlugins: test.postSchedulePlugins, + filters: test.filters, + scorers: test.scorers, + picker: test.picker, + } + + req := &types.LLMRequest{Model: "test-model"} + got, err := scheduler.Schedule(context.Background(), req) + + // Validate error state + if test.err != (err != nil) { + t.Fatalf("Unexpected error, got %v, want %v", err, test.err) + } + + if err != nil { + return + } + + // Validate output + opt := cmp.AllowUnexported(types.PodMetrics{}) + wantPod := &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: test.wantTargetPod}, + } + wantPod.SetScore(test.targetPodScore) + wantRes := &types.Result{TargetPod: wantPod} + if diff := cmp.Diff(wantRes, got, opt); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + + // Validate plugin execution counts dynamically + for _, plugin := range test.preSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PreScheduleCallCount != 1 { + t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", tp.NameRes, tp.PreScheduleCallCount) + } + } + + for _, plugin := range test.postSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PostScheduleCallCount != 1 { + t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) + } + } + + for _, plugin := range test.filters { + tp, _ := plugin.(*TestPlugin) + if tp.FilterCallCount != 1 { + t.Errorf("Plugin %s Filter() called %d times, expected 1", tp.NameRes, tp.FilterCallCount) + } + } + + for _, plugin := range test.scorers { + tp, _ := plugin.(*TestPlugin) + if tp.ScoreCallCount != test.numPodsToScore { + t.Errorf("Plugin %s Score() called %d times, expected 1", tp.NameRes, tp.ScoreCallCount) + } + } + + tp, _ := test.picker.(*TestPlugin) + if tp.PickCallCount != 1 { + t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.NameRes, tp.PickCallCount) + } + + }) + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -230,3 +438,68 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics { } return pm } + +// TestPlugin is an implementation useful in unit tests. 
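+// It counts how many times each hook is invoked and returns the canned results and errors configured on its fields.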
+type TestPlugin struct { + NameRes string + ScoreCallCount int + ScoreRes float64 + ScoreErr error + FilterCallCount int + FilterRes []k8stypes.NamespacedName + FilterErr error + PreScheduleCallCount int + PostScheduleCallCount int + PickCallCount int + PickRes k8stypes.NamespacedName + PickErr error +} + +func (tp *TestPlugin) Name() string { return tp.NameRes } + +func (tp *TestPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { + tp.ScoreCallCount++ + return tp.ScoreRes, tp.ScoreErr +} + +func (tp *TestPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + tp.FilterCallCount++ + return findPods(ctx, tp.FilterRes...), tp.FilterErr +} + +func (tp *TestPlugin) PreSchedule(ctx *types.Context) { + tp.PreScheduleCallCount++ +} + +func (tp *TestPlugin) PostSchedule(ctx *types.Context, res *types.Result) { + tp.PostScheduleCallCount++ +} + +func (tp *TestPlugin) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { + tp.PickCallCount++ + if tp.PickErr != nil { + return nil, tp.PickErr + } + pod := findPods(ctx, tp.PickRes)[0] + return &types.Result{TargetPod: pod}, nil +} + +func (tp *TestPlugin) Reset() { + tp.PreScheduleCallCount = 0 + tp.PostScheduleCallCount = 0 + tp.FilterCallCount = 0 + tp.ScoreCallCount = 0 + tp.PickCallCount = 0 +} + +func findPods(ctx *types.Context, names ...k8stypes.NamespacedName) []types.Pod { + res := []types.Pod{} + for _, pod := range ctx.PodsSnapshot { + for _, name := range names { + if pod.GetPod().NamespacedName.String() == name.String() { + res = append(res, pod) + } + } + } + return res +} diff --git a/pkg/epp/scheduling/types/interfaces.go b/pkg/epp/scheduling/types/interfaces.go new file mode 100644 index 000000000..6e954cef0 --- /dev/null +++ b/pkg/epp/scheduling/types/interfaces.go @@ -0,0 +1,75 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package types + +import ( + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" +) + +const ( + PreSchedulerPluginType = "PreSchedule" + PostSchedulePluginType = "PostSchedule" + FilterPluginType = "Filter" + ScorerPluginType = "Scorer" + PickerPluginType = "Picker" +) + +type Pod interface { + GetPod() *backendmetrics.Pod + GetMetrics() *backendmetrics.Metrics + SetScore(float64) + Score() float64 + String() string +} + +// Plugin defines the interface for scheduler plugins, combining scoring, filtering, +// and event handling capabilities. +type Plugin interface { + // Name returns the name of the plugin. + Name() string +} + +// PreSchedule is called when the scheduler receives a new request. It can be used for various +// initialization work. +type PreSchedule interface { + Plugin + PreSchedule(ctx *Context) +} + +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { + Plugin + PostSchedule(ctx *Context, res *Result) +} + +// Filter defines the interface for filtering a list of pods based on context. 
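+// Filters run in sequence; each receives the pods that survived the previous filter, and an error or an empty result aborts scheduling.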
+type Filter interface { + Plugin + Filter(ctx *Context, pods []Pod) ([]Pod, error) +} + +// Scorer defines the interface for scoring pods based on context. +type Scorer interface { + Plugin + Score(ctx *Context, pod Pod) (float64, error) +} + +// Picker picks the final pod(s) to send the request to. +type Picker interface { + Plugin + Pick(ctx *Context, pods []Pod) (*Result, error) +} diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 9450652ed..e52e90472 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -30,23 +30,22 @@ type LLMRequest struct { Model string // Target models is a map of target model name to weight. TargetModels map[string]int + Prompt string // Resolved target model is the final target model after traffic split. ResolvedTargetModel string Critical bool } +func (r *LLMRequest) String() string { + return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, len(r.Prompt)) +} + // Context holds contextual information during a scheduling operation. type Context struct { context.Context Logger logr.Logger Req *LLMRequest - PodsSnapshot []*PodMetrics -} - -type Pod interface { - GetPod() *backendmetrics.Pod - GetMetrics() *backendmetrics.Metrics - String() string + PodsSnapshot []Pod } func (pm *PodMetrics) String() string { @@ -64,12 +63,21 @@ func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { return pm.Metrics } +func (pm *PodMetrics) SetScore(score float64) { + pm.score = score +} + +func (pm *PodMetrics) Score() float64 { + return pm.score +} + type PodMetrics struct { + score float64 *backendmetrics.Pod *backendmetrics.Metrics } -func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Context { +func NewContext(ctx context.Context, req *LLMRequest, pods []Pod) *Context { logger := log.FromContext(ctx).WithValues("request", req) return &Context{ Context: ctx, @@ -79,10 +87,15 @@ func NewContext(ctx context.Context, req *LLMRequest, pods []*PodMetrics) *Conte } } -func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []*PodMetrics { - pm := make([]*PodMetrics, 0, len(pods)) +func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { + pm := make([]Pod, 0, len(pods)) for _, pod := range pods { - pm = append(pm, &PodMetrics{pod.GetPod().Clone(), pod.GetMetrics().Clone()}) + pm = append(pm, &PodMetrics{Pod: pod.GetPod().Clone(), Metrics: pod.GetMetrics().Clone()}) } return pm } + +// Result captures the scheduler result. 
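+// TargetPod is the pod selected by the Picker plugin for the request.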
+type Result struct { + TargetPod Pod +} From bf64a9398fe0f76399c6f7d3f605c4ce2d801bf3 Mon Sep 17 00:00:00 2001 From: Nicole Xin Date: Tue, 22 Apr 2025 17:51:42 -0700 Subject: [PATCH 09/20] Complete the InferencePool documentation (#673) * Initial guide for inference pool * Add extensionReference to the InferencePool spec * Fix list formatting * Remove unused labels * Autogenerate the spec * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Update site-src/api-types/inferencepool.md Co-authored-by: Rob Scott * Rename llm-pool names in rollout example * Add use cases for replacing an inference pool * Rewording the background section * Create replacing-inference-pool.md * Replace instructions with a link for how to replace an inference pool * Update replacing-inference-pool.md * Update mkdocs.yml * Update replacing-inference-pool.md * Update inferencemodel_types.go * Update inferencepool.md * Update site-src/guides/replacing-inference-pool.md Co-authored-by: Rob Scott --------- Co-authored-by: Rob Scott --- api/v1alpha2/inferencemodel_types.go | 2 +- mkdocs.yml | 1 + site-src/api-types/inferencepool.md | 58 +++- site-src/guides/replacing-inference-pool.md | 59 ++++ site-src/reference/spec.md | 288 +++++++++++++++++--- 5 files changed, 352 insertions(+), 56 deletions(-) create mode 100644 site-src/guides/replacing-inference-pool.md diff --git a/api/v1alpha2/inferencemodel_types.go b/api/v1alpha2/inferencemodel_types.go index 052683d88..7cd98a740 100644 --- a/api/v1alpha2/inferencemodel_types.go +++ b/api/v1alpha2/inferencemodel_types.go @@ -126,7 +126,7 @@ type PoolObjectReference struct { } // Criticality defines how important it is to serve the model compared to other models. -// Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional(use a pointer), and set no default. +// Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional (use a pointer), and set no default. // This allows us to union this with a oneOf field in the future should we wish to adjust/extend this behavior. // +kubebuilder:validation:Enum=Critical;Standard;Sheddable type Criticality string diff --git a/mkdocs.yml b/mkdocs.yml index bdfffe057..e5927ed53 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,6 +63,7 @@ nav: - Getting started: guides/index.md - Adapter Rollout: guides/adapter-rollout.md - Metrics: guides/metrics.md + - Replacing an Inference Pool: guides/replacing-inference-pool.md - Implementer's Guide: guides/implementers.md - Performance: - Benchmark: performance/benchmark/index.md diff --git a/site-src/api-types/inferencepool.md b/site-src/api-types/inferencepool.md index baa604b61..1494d314e 100644 --- a/site-src/api-types/inferencepool.md +++ b/site-src/api-types/inferencepool.md @@ -7,28 +7,56 @@ ## Background -The InferencePool resource is a logical grouping of compute resources, e.g. Pods, that run model servers. The InferencePool would deploy its own routing, and offer administrative configuration to the Platform Admin. 
+The **InferencePool** API defines a group of Pods (containers) dedicated to serving AI models. Pods within an InferencePool share the same compute configuration, accelerator type, base language model, and model server. This abstraction simplifies the management of AI model serving resources, providing a centralized point of administrative configuration for Platform Admins. -It is expected for the InferencePool to: +An InferencePool is expected to be bundled with an [Endpoint Picker](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp) extension. This extension is responsible for tracking key metrics on each model server (i.e. the KV-cache utilization, queue length of pending requests, active LoRA adapters, etc.) and routing incoming inference requests to the optimal model server replica based on these metrics. An EPP can only be associated with a single InferencePool. The associated InferencePool is specified by the [poolName](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/config/manifests/inferencepool-resources.yaml#L54) and [poolNamespace](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/config/manifests/inferencepool-resources.yaml#L56) flags. An HTTPRoute can have multiple backendRefs that reference the same InferencePool and therefore routes to the same EPP. An HTTPRoute can have multiple backendRefs that reference different InferencePools and therefore routes to different EPPs. - - Enforce fair consumption of resources across competing workloads - - Efficiently route requests across shared compute (as displayed by the PoC) - -It is _not_ expected for the InferencePool to: +Additionally, any Pod that seeks to join an InferencePool would need to support the [model server protocol](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/003-model-server-protocol), defined by this project, to ensure the Endpoint Picker has adequate information to intelligently route requests. - - Enforce any common set of adapters or base models are available on the Pods - - Manage Deployments of Pods within the Pool - - Manage Pod lifecycle of pods within the pool +## How to Configure an InferencePool -Additionally, any Pod that seeks to join an InferencePool would need to support a protocol, defined by this project, to ensure the Pool has adequate information to intelligently route requests. +The full spec of the InferencePool is defined [here](/reference/spec/#inferencepool). -`InferencePool` has some small overlap with `Service`, displayed here: +In summary, the InferencePoolSpec consists of 3 major parts: + +- The `selector` field specifies which Pods belong to this pool. The labels in this selector must exactly match the labels applied to your model server Pods. +- The `targetPortNumber` field defines the port number that the Inference Gateway should route to on model server Pods that belong to this pool. +- The `extensionRef` field references the [endpoint picker extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp) (EPP) service that monitors key metrics from model servers within the InferencePool and provides intelligent routing decisions. 
+ +### Example Configuration + +Here is an example InferencePool configuration: + +``` +apiVersion: inference.networking.x-k8s.io/v1alpha2 +kind: InferencePool +metadata: + name: vllm-llama3-8b-instruct +spec: + targetPortNumber: 8000 + selector: + app: vllm-llama3-8b-instruct + extensionRef: + name: vllm-llama3-8b-instruct-epp + port: 9002 + failureMode: FailClose +``` + +In this example: + +- An InferencePool named `vllm-llama3-8b-instruct` is created in the `default` namespace. +- It will select Pods that have the label `app: vllm-llama3-8b-instruct`. +- Traffic routed to this InferencePool will call out to the EPP service `vllm-llama3-8b-instruct-epp` on port `9002` for making routing decisions. If the EPP fails to pick an endpoint, or is not responsive, the request will be dropped. +- Traffic routed to this InferencePool will be forwarded to the port `8000` on the selected Pods. + +## Overlap with Service + +**InferencePool** has some small overlap with **Service**, displayed here: Comparing InferencePool with Service -The InferencePool is _not_ intended to be a mask of the Service object, simply exposing the absolute bare minimum required to allow the Platform Admin to focus less on networking, and more on Pool management. - -## Spec +The InferencePool is not intended to be a mask of the Service object. It provides a specialized abstraction tailored for managing and routing traffic to groups of LLM model servers, allowing Platform Admins to focus on pool-level management rather than low-level networking details. -The full spec of the InferencePool is defined [here](/reference/spec/#inferencepool). \ No newline at end of file +## Replacing an InferencePool +Please refer to the [Replacing an InferencePool](/guides/replacing-inference-pool) guide for details on use cases and how to replace an InferencePool. diff --git a/site-src/guides/replacing-inference-pool.md b/site-src/guides/replacing-inference-pool.md new file mode 100644 index 000000000..212945706 --- /dev/null +++ b/site-src/guides/replacing-inference-pool.md @@ -0,0 +1,59 @@ +# Replacing an InferencePool + +## Background + +Replacing an InferencePool is a powerful technique for performing various infrastructure and model updates with minimal disruption and built-in rollback capabilities. This method allows you to introduce changes incrementally, monitor their impact, and revert to the previous state if necessary. + +## Use Cases +Common reasons for replacing an InferencePool include: + +- Upgrading or replacing your model server framework +- Upgrading or replacing your base model +- Transitioning to new hardware + +## How to replace an InferencePool + +To replace an InferencePool: + +1. **Deploy new infrastructure**: Create a new InferencePool configured with the new hardware / model server / base model that you chose. +1. **Configure traffic splitting**: Use an HTTPRoute to split traffic between the existing InferencePool and the new InferencePool. The `backendRefs.weight` field controls the traffic percentage allocated to each pool. +1. **Maintain InferenceModel integrity**: Keep your InferenceModel configuration unchanged. This ensures that the system applies the same LoRA adapters consistently across both base model versions. +1. **Preserve rollback capability**: Retain the original nodes and InferencePool during the rollout to facilitate a rollback if necessary. + +### Example + +You start with an existing InferencePool named `llm-pool-v1`. To replace the original InferencePool, you create a new InferencePool named `llm-pool-v2`. 
By configuring an **HTTPRoute**, as shown below, you can incrementally split traffic between the original `llm-pool-v1` and new `llm-pool-v2`. + +1. Save the following sample manifest as `httproute.yaml`: + + ```yaml + apiVersion: gateway.networking.k8s.io/v1 + kind: HTTPRoute + metadata: + name: llm-route + spec: + parentRefs: + - group: gateway.networking.k8s.io + kind: Gateway + name: inference-gateway + rules: + backendRefs: + - group: inference.networking.x-k8s.io + kind: InferencePool + name: llm-pool-v1 + weight: 90 + - group: inference.networking.x-k8s.io + kind: InferencePool + name: llm-pool-v2 + weight: 10 + ``` + +1. Apply the sample manifest to your cluster: + + ``` + kubectl apply -f httproute.yaml + ``` + + The original `llm-pool-v1` InferencePool receives most of the traffic, while the `llm-pool-v2` InferencePool receives the rest. + +1. Increase the traffic weight gradually for the `llm-pool-v2` InferencePool to complete the new InferencePool roll out. diff --git a/site-src/reference/spec.md b/site-src/reference/spec.md index e16c113c1..d8e0c95bf 100644 --- a/site-src/reference/spec.md +++ b/site-src/reference/spec.md @@ -1,12 +1,14 @@ # API Reference ## Packages -- [inference.networking.x-k8s.io/v1alpha1](#inferencenetworkingx-k8siov1alpha1) +- [inference.networking.x-k8s.io/v1alpha2](#inferencenetworkingx-k8siov1alpha2) -## inference.networking.x-k8s.io/v1alpha1 +## inference.networking.x-k8s.io/v1alpha2 + +Package v1alpha2 contains API Schema definitions for the +inference.networking.x-k8s.io API group. -Package v1alpha1 contains API Schema definitions for the gateway v1alpha1 API group ### Resource Types - [InferenceModel](#inferencemodel) @@ -18,26 +20,152 @@ Package v1alpha1 contains API Schema definitions for the gateway v1alpha1 API gr _Underlying type:_ _string_ -Defines how important it is to serve the model compared to other models. +Criticality defines how important it is to serve the model compared to other models. +Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional(use a pointer), and set no default. +This allows us to union this with a oneOf field in the future should we wish to adjust/extend this behavior. _Validation:_ -- Enum: [Critical Default Sheddable] +- Enum: [Critical Standard Sheddable] _Appears in:_ - [InferenceModelSpec](#inferencemodelspec) | Field | Description | | --- | --- | -| `Critical` | Most important. Requests to this band will be shed last.
| -| `Default` | More important than Sheddable, less important than Critical.
Requests in this band will be shed before critical traffic.
+kubebuilder:default=Default
| -| `Sheddable` | Least important. Requests to this band will be shed before all other bands.
| +| `Critical` | Critical defines the highest level of criticality. Requests to this band will be shed last.
| +| `Standard` | Standard defines the base criticality level and is more important than Sheddable but less
important than Critical. Requests in this band will be shed before critical traffic.
Most models are expected to fall within this band.
| +| `Sheddable` | Sheddable defines the lowest level of criticality. Requests to this band will be shed before
all other bands.
| + + +#### EndpointPickerConfig + + + +EndpointPickerConfig specifies the configuration needed by the proxy to discover and connect to the endpoint picker extension. +This type is intended to be a union of mutually exclusive configuration options that we may add in the future. + + + +_Appears in:_ +- [InferencePoolSpec](#inferencepoolspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `extensionRef` _[Extension](#extension)_ | Extension configures an endpoint picker as an extension service. | | Required: \{\}
| + + +#### Extension + + + +Extension specifies how to configure an extension that runs the endpoint picker. + + + +_Appears in:_ +- [EndpointPickerConfig](#endpointpickerconfig) +- [InferencePoolSpec](#inferencepoolspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `group` _[Group](#group)_ | Group is the group of the referent.
The default value is "", representing the Core API group. | | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is the Kubernetes resource kind of the referent. For example
"Service".
Defaults to "Service" when not specified.
ExternalName services can refer to CNAME DNS records that may live
outside of the cluster and as such are difficult to reason about in
terms of conformance. They also may not be safe to forward to (see
CVE-2021-25740 for more information). Implementations MUST NOT
support ExternalName Services. | Service | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `portNumber` _[PortNumber](#portnumber)_ | The port number on the service running the extension. When unspecified,
implementations SHOULD infer a default value of 9002 when the Kind is
Service. | | Maximum: 65535
Minimum: 1
| +| `failureMode` _[ExtensionFailureMode](#extensionfailuremode)_ | Configures how the gateway handles the case when the extension is not responsive.
Defaults to failClose. | FailClose | Enum: [FailOpen FailClose]
| + + +#### ExtensionConnection + + + +ExtensionConnection encapsulates options that configures the connection to the extension. + + + +_Appears in:_ +- [Extension](#extension) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `failureMode` _[ExtensionFailureMode](#extensionfailuremode)_ | Configures how the gateway handles the case when the extension is not responsive.
Defaults to failClose. | FailClose | Enum: [FailOpen FailClose]
| + + +#### ExtensionFailureMode + +_Underlying type:_ _string_ + +ExtensionFailureMode defines the options for how the gateway handles the case when the extension is not +responsive. + +_Validation:_ +- Enum: [FailOpen FailClose] + +_Appears in:_ +- [Extension](#extension) +- [ExtensionConnection](#extensionconnection) + +| Field | Description | +| --- | --- | +| `FailOpen` | FailOpen specifies that the proxy should not drop the request and forward the request to an endpoint of its picking.
| +| `FailClose` | FailClose specifies that the proxy should drop the request.
| + + +#### ExtensionReference + + + +ExtensionReference is a reference to the extension deployment. + + + +_Appears in:_ +- [Extension](#extension) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `group` _[Group](#group)_ | Group is the group of the referent.
The default value is "", representing the Core API group. | | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is the Kubernetes resource kind of the referent. For example
"Service".
Defaults to "Service" when not specified.
ExternalName services can refer to CNAME DNS records that may live
outside of the cluster and as such are difficult to reason about in
terms of conformance. They also may not be safe to forward to (see
CVE-2021-25740 for more information). Implementations MUST NOT
support ExternalName Services. | Service | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `portNumber` _[PortNumber](#portnumber)_ | The port number on the service running the extension. When unspecified,
implementations SHOULD infer a default value of 9002 when the Kind is
Service. | | Maximum: 65535
Minimum: 1
| + + +#### Group + +_Underlying type:_ _string_ + +Group refers to a Kubernetes Group. It must either be an empty string or a +RFC 1123 subdomain. + +This validation is based off of the corresponding Kubernetes validation: +https://github.com/kubernetes/apimachinery/blob/02cfb53916346d085a6c6c7c66f882e3c6b0eca6/pkg/util/validation/validation.go#L208 + +Valid values include: + +* "" - empty string implies core Kubernetes API group +* "gateway.networking.k8s.io" +* "foo.example.com" + +Invalid values include: + +* "example.com/bar" - "/" is an invalid character + +_Validation:_ +- MaxLength: 253 +- Pattern: `^$|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$` + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + #### InferenceModel -InferenceModel is the Schema for the InferenceModels API +InferenceModel is the Schema for the InferenceModels API. @@ -45,29 +173,31 @@ InferenceModel is the Schema for the InferenceModels API | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha1` | | | +| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha2` | | | | `kind` _string_ | `InferenceModel` | | | | `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `spec` _[InferenceModelSpec](#inferencemodelspec)_ | | | | | `status` _[InferenceModelStatus](#inferencemodelstatus)_ | | | | + + + + #### InferenceModelSpec -InferenceModelSpec represents a specific model use case. This resource is +InferenceModelSpec represents the desired state of a specific model use case. This resource is managed by the "Inference Workload Owner" persona. - -The Inference Workload Owner persona is: a team that trains, verifies, and +The Inference Workload Owner persona is someone that trains, verifies, and leverages a large language model from a model frontend, drives the lifecycle and rollout of new versions of those models, and defines the specific performance and latency goals for the model. These workloads are expected to operate within an InferencePool sharing compute capacity with other InferenceModels, defined by the Inference Platform Admin. - InferenceModel's modelName (not the ObjectMeta name) is unique for a given InferencePool, if the name is reused, an error will be shown on the status of a InferenceModel that attempted to reuse. The oldest InferenceModel, based on @@ -81,10 +211,10 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `modelName` _string_ | The name of the model as the users set in the "model" parameter in the requests.
The name should be unique among the workloads that reference the same backend pool.
This is the parameter that will be used to match the request with. In the future, we may
allow to match on other request parameters. The other approach to support matching on
on other request parameters is to use a different ModelName per HTTPFilter.
Names can be reserved without implementing an actual model in the pool.
This can be done by specifying a target model and setting the weight to zero,
an error will be returned specifying that no valid target model is found. | | MaxLength: 253
| -| `criticality` _[Criticality](#criticality)_ | Defines how important it is to serve the model compared to other models referencing the same pool. | Default | Enum: [Critical Default Sheddable]
| -| `targetModels` _[TargetModel](#targetmodel) array_ | Allow multiple versions of a model for traffic splitting.
If not specified, the target model name is defaulted to the modelName parameter.
modelName is often in reference to a LoRA adapter. | | MaxItems: 10
| -| `poolRef` _[PoolObjectReference](#poolobjectreference)_ | Reference to the inference pool, the pool must exist in the same namespace. | | Required: \{\}
| +| `modelName` _string_ | ModelName is the name of the model as it will be set in the "model" parameter for an incoming request.
ModelNames must be unique for a referencing InferencePool
(names can be reused for a different pool in the same cluster).
The modelName with the oldest creation timestamp is retained, and the incoming
InferenceModel has its Ready status set to false with a corresponding reason.
In the rare case of a race condition, one Model will be selected randomly to be considered valid, and the other rejected.
Names can be reserved without an underlying model configured in the pool.
This can be done by specifying a target model and setting the weight to zero;
an error will be returned specifying that no valid target model is found. | | MaxLength: 256
Required: \{\}
| +| `criticality` _[Criticality](#criticality)_ | Criticality defines how important it is to serve the model compared to other models referencing the same pool.
Criticality impacts how traffic is handled in resource constrained situations. It handles this by
queuing or rejecting requests of lower criticality. InferenceModels of an equivalent Criticality will
fairly share resources over throughput of tokens. In the future, the metric used to calculate fairness,
and the proportionality of fairness will be configurable.
Default values for this field will not be set, to allow for future additions of new fields that may be 'oneOf' with this field.
Any implementations that may consume this field may treat an unset value as the 'Standard' range. | | Enum: [Critical Standard Sheddable]
| +| `targetModels` _[TargetModel](#targetmodel) array_ | TargetModels allow multiple versions of a model for traffic splitting.
If not specified, the target model name is defaulted to the modelName parameter.
modelName is often in reference to a LoRA adapter. | | MaxItems: 10
| +| `poolRef` _[PoolObjectReference](#poolobjectreference)_ | PoolRef is a reference to the inference pool, the pool must exist in the same namespace. | | Required: \{\}
| #### InferenceModelStatus @@ -100,14 +230,14 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool. | | | +| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferenceModel.
Known condition types are:
* "Accepted" | [map[lastTransitionTime:1970-01-01T00:00:00Z message:Waiting for controller reason:Pending status:Unknown type:Ready]] | MaxItems: 8
| #### InferencePool -InferencePool is the Schema for the Inferencepools API +InferencePool is the Schema for the InferencePools API. @@ -115,13 +245,17 @@ InferencePool is the Schema for the Inferencepools API | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha1` | | | +| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha2` | | | | `kind` _string_ | `InferencePool` | | | | `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `spec` _[InferencePoolSpec](#inferencepoolspec)_ | | | | | `status` _[InferencePoolStatus](#inferencepoolstatus)_ | | | | + + + + #### InferencePoolSpec @@ -135,8 +269,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `selector` _object (keys:[LabelKey](#labelkey), values:[LabelValue](#labelvalue))_ | Selector uses a map of label to watch model server pods
that should be included in the InferencePool. ModelServers should not
be with any other Service or InferencePool, that behavior is not supported
and will result in sub-optimal utilization.
In some cases, implementations may translate this to a Service selector, so this matches the simple
map used for Service selectors instead of the full Kubernetes LabelSelector type. | | Required: \{\}
| -| `targetPortNumber` _integer_ | TargetPortNumber is the port number that the model servers within the pool expect
to receive traffic from.
This maps to the TargetPort in: https://pkg.go.dev/k8s.io/api/core/v1#ServicePort | | Maximum: 65535
Minimum: 0
Required: \{\}
| +| `selector` _object (keys:[LabelKey](#labelkey), values:[LabelValue](#labelvalue))_ | Selector defines a map of labels to watch model server pods
that should be included in the InferencePool.
In some cases, implementations may translate this field to a Service selector, so this matches the simple
map used for Service selectors instead of the full Kubernetes LabelSelector type.
If specified, it will be applied to match the model server pods in the same namespace as the InferencePool.
Cross-namespace selectors are not supported. | | Required: \{\}
| +| `targetPortNumber` _integer_ | TargetPortNumber defines the port number to access the selected model servers.
The number must be in the range 1 to 65535. | | Maximum: 65535
Minimum: 1
Required: \{\}
| +| `extensionRef` _[Extension](#extension)_ | Extension configures an endpoint picker as an extension service. | | Required: \{\}
| #### InferencePoolStatus @@ -152,33 +287,56 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool. | | | +| `parent` _[PoolStatus](#poolstatus) array_ | Parents is a list of parent resources (usually Gateways) that are
associated with the InferencePool, and the status of the InferencePool with respect to
each parent.
A maximum of 32 Gateways will be represented in this list. An empty list
means the InferencePool has not been attached to any Gateway. | | MaxItems: 32
| + + +#### Kind + +_Underlying type:_ _string_ + +Kind refers to a Kubernetes Kind. + +Valid values include: + +* "Service" +* "HTTPRoute" + +Invalid values include: + +* "invalid/kind" - "/" is an invalid character + +_Validation:_ +- MaxLength: 63 +- MinLength: 1 +- Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$` + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + #### LabelKey _Underlying type:_ _string_ -Originally copied from: https://github.com/kubernetes-sigs/gateway-api/blob/99a3934c6bc1ce0874f3a4c5f20cafd8977ffcb4/apis/v1/shared_types.go#L694-L731 +LabelKey was originally copied from: https://github.com/kubernetes-sigs/gateway-api/blob/99a3934c6bc1ce0874f3a4c5f20cafd8977ffcb4/apis/v1/shared_types.go#L694-L731 Duplicated as to not take an unexpected dependency on gw's API. - LabelKey is the key of a label. This is used for validation of maps. This matches the Kubernetes "qualified name" validation that is used for labels. - +Labels are case sensitive, so: my-label and My-Label are considered distinct. Valid values include: - * example * example.com * example.com/path * example.com/path.html - Invalid values include: - * example~ - "~" is an invalid character * example.com. - can not start or end with "." @@ -202,10 +360,8 @@ of maps. This matches the Kubernetes label validation rules: * unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]), * could contain dashes (-), underscores (_), dots (.), and alphanumerics between. - Valid values include: - * MyValue * my.name * 123-my-value @@ -220,6 +376,25 @@ _Appears in:_ +#### ObjectName + +_Underlying type:_ _string_ + +ObjectName refers to the name of a Kubernetes object. +Object names can have a variety of forms, including RFC 1123 subdomains, +RFC 1123 labels, or RFC 1035 labels. + +_Validation:_ +- MaxLength: 253 +- MinLength: 1 + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + + + #### PoolObjectReference @@ -234,9 +409,42 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `group` _string_ | Group is the group of the referent. | inference.networking.x-k8s.io | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| -| `kind` _string_ | Kind is kind of the referent. For example "InferencePool". | InferencePool | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| -| `name` _string_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `group` _[Group](#group)_ | Group is the group of the referent. | inference.networking.x-k8s.io | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is kind of the referent. For example "InferencePool". | InferencePool | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| + + +#### PoolStatus + + + +PoolStatus defines the observed state of InferencePool from a Gateway. + + + +_Appears in:_ +- [InferencePoolStatus](#inferencepoolstatus) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `parentRef` _[ObjectReference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectreference-v1-core)_ | GatewayRef indicates the gateway that observed state of InferencePool. | | | +| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool.
Known condition types are:
* "Accepted"
* "ResolvedRefs" | [map[lastTransitionTime:1970-01-01T00:00:00Z message:Waiting for controller reason:Pending status:Unknown type:Accepted]] | MaxItems: 8
| + + +#### PortNumber + +_Underlying type:_ _integer_ + +PortNumber defines a network port. + +_Validation:_ +- Maximum: 65535 +- Minimum: 1 + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) + #### TargetModel @@ -246,10 +454,10 @@ _Appears in:_ TargetModel represents a deployed model or a LoRA adapter. The Name field is expected to match the name of the LoRA adapter (or base model) as it is registered within the model server. Inference -Gateway assumes that the model exists on the model server and is the +Gateway assumes that the model exists on the model server and it's the responsibility of the user to validate a correct match. Should a model fail -to exist at request time, the error is processed by the Instance Gateway, -and then emitted on the appropriate InferenceModel object. +to exist at request time, the error is processed by the Inference Gateway +and emitted on the appropriate InferenceModel object. @@ -258,7 +466,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `name` _string_ | The name of the adapter as expected by the ModelServer. | | MaxLength: 253
| -| `weight` _integer_ | Weight is used to determine the proportion of traffic that should be
sent to this target model when multiple versions of the model are specified. | 1 | Maximum: 1e+06
Minimum: 0
| +| `name` _string_ | Name is the name of the adapter or base model, as expected by the ModelServer. | | MaxLength: 253
Required: \{\}
| +| `weight` _integer_ | Weight is used to determine the proportion of traffic that should be
sent to this model when multiple target models are specified.
Weight defines the proportion of requests forwarded to the specified
model. This is computed as weight/(sum of all weights in this
TargetModels list). For non-zero values, there may be some epsilon from
the exact proportion defined here depending on the precision an
implementation supports. Weight is not a percentage and the sum of
weights does not need to equal 100.
If a weight is set for any targetModel, it must be set for all targetModels.
Conversely, weights are optional, so long as ALL targetModels do not specify a weight. | | Maximum: 1e+06
Minimum: 1
| From 76a562f767bcf93c0b2a281faf5b2a7689255ba5 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Wed, 23 Apr 2025 04:27:40 +0300 Subject: [PATCH 10/20] reduce log level in metrics logger not to trash the log (#708) * reduce log level in metrics logger not to trash the log Signed-off-by: Nir Rozenbaum * rename flush metrics to refresh metrics Signed-off-by: Nir Rozenbaum * revert log level Signed-off-by: Nir Rozenbaum --------- Signed-off-by: Nir Rozenbaum --- cmd/epp/main.go | 9 ++++----- pkg/epp/backend/metrics/logger.go | 10 +++++----- pkg/epp/server/runserver.go | 17 ++++++----------- test/integration/epp/hermetic_test.go | 2 +- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index b5e6fbe6b..c0a87e62e 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -142,8 +142,8 @@ func run() error { } poolNamespacedName := types.NamespacedName{ - Namespace: *poolNamespace, Name: *poolName, + Namespace: *poolNamespace, } mgr, err := runserver.NewDefaultManager(poolNamespacedName, cfg) if err != nil { @@ -151,8 +151,6 @@ func run() error { return err } - ctx := ctrl.SetupSignalHandler() - // Set up mapper for metric scraping. mapping, err := backendmetrics.NewMetricMapping( *totalQueuedRequestsMetric, @@ -167,14 +165,15 @@ func run() error { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{MetricMapping: mapping}, *refreshMetricsInterval) // Setup runner. + ctx := ctrl.SetupSignalHandler() + datastore := datastore.NewDatastore(ctx, pmf) serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, DestinationEndpointHintKey: *destinationEndpointHintKey, - PoolName: *poolName, - PoolNamespace: *poolNamespace, + PoolNamespacedName: poolNamespacedName, Datastore: datastore, SecureServing: *secureServing, CertPath: *certPath, diff --git a/pkg/epp/backend/metrics/logger.go b/pkg/epp/backend/metrics/logger.go index d9a930277..7dc1a8b8b 100644 --- a/pkg/epp/backend/metrics/logger.go +++ b/pkg/epp/backend/metrics/logger.go @@ -55,8 +55,8 @@ func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometh case <-ctx.Done(): logger.V(logutil.DEFAULT).Info("Shutting down prometheus metrics thread") return - case <-ticker.C: // Periodically flush prometheus metrics for inference pool - flushPrometheusMetricsOnce(logger, datastore) + case <-ticker.C: // Periodically refresh prometheus metrics for inference pool + refreshPrometheusMetrics(logger, datastore) } } }() @@ -86,11 +86,11 @@ func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometh } } -func flushPrometheusMetricsOnce(logger logr.Logger, datastore Datastore) { +func refreshPrometheusMetrics(logger logr.Logger, datastore Datastore) { pool, err := datastore.PoolGet() if err != nil { // No inference pool or not initialize. 
- logger.V(logutil.DEFAULT).Info("pool is not initialized, skipping flushing metrics") + logger.V(logutil.DEFAULT).Info("Pool is not initialized, skipping refreshing metrics") return } @@ -98,7 +98,7 @@ func flushPrometheusMetricsOnce(logger logr.Logger, datastore Datastore) { var queueTotal int podMetrics := datastore.PodGetAll() - logger.V(logutil.VERBOSE).Info("Flushing Prometheus Metrics", "ReadyPods", len(podMetrics)) + logger.V(logutil.TRACE).Info("Refreshing Prometheus Metrics", "ReadyPods", len(podMetrics)) if len(podMetrics) == 0 { return } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 65a6e7879..0c0a6a6dc 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -43,8 +43,7 @@ type ExtProcServerRunner struct { GrpcPort int DestinationEndpointHintMetadataNamespace string DestinationEndpointHintKey string - PoolName string - PoolNamespace string + PoolNamespacedName types.NamespacedName Datastore datastore.Datastore SecureServing bool CertPath string @@ -73,8 +72,7 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { GrpcPort: DefaultGrpcPort, DestinationEndpointHintKey: DefaultDestinationEndpointHintKey, DestinationEndpointHintMetadataNamespace: DefaultDestinationEndpointHintMetadataNamespace, - PoolName: DefaultPoolName, - PoolNamespace: DefaultPoolNamespace, + PoolNamespacedName: types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace}, SecureServing: DefaultSecureServing, RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval, // Datastore can be assigned later. @@ -93,13 +91,10 @@ func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Man } if err := (&controller.InferenceModelReconciler{ - Datastore: r.Datastore, - Client: mgr.GetClient(), - PoolNamespacedName: types.NamespacedName{ - Name: r.PoolName, - Namespace: r.PoolNamespace, - }, - Record: mgr.GetEventRecorderFor("InferenceModel"), + Datastore: r.Datastore, + Client: mgr.GetClient(), + PoolNamespacedName: r.PoolNamespacedName, + Record: mgr.GetEventRecorderFor("InferenceModel"), }).SetupWithManager(ctx, mgr); err != nil { return fmt.Errorf("failed setting up InferenceModelReconciler: %w", err) } diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 372158f4b..79b619fd6 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -1348,7 +1348,7 @@ func BeforeSuite() func() { serverRunner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{} pmf := backendmetrics.NewPodMetricsFactory(serverRunner.TestPodMetricsClient, 10*time.Millisecond) // Adjust from defaults - serverRunner.PoolName = "vllm-llama3-8b-instruct-pool" + serverRunner.PoolNamespacedName = types.NamespacedName{Name: "vllm-llama3-8b-instruct-pool", Namespace: "default"} serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf) serverRunner.SecureServing = false From 7792676a2bc23092857617d475631dedecbb4edb Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Wed, 23 Apr 2025 04:27:47 +0300 Subject: [PATCH 11/20] few updates in datastore (#713) * few updates in datastore Signed-off-by: Nir Rozenbaum * PoolSet documentation Signed-off-by: Nir Rozenbaum * error phrasing Signed-off-by: Nir Rozenbaum * removed unused pool arg from PodUpdateOrAddIfNotExist Signed-off-by: Nir Rozenbaum * linter Signed-off-by: Nir Rozenbaum --------- Signed-off-by: Nir Rozenbaum --- .../inferencemodel_reconciler_test.go | 5 +- 
.../controller/inferencepool_reconciler.go | 24 ++---- pkg/epp/controller/pod_reconciler.go | 12 ++- pkg/epp/controller/pod_reconciler_test.go | 4 +- pkg/epp/datastore/datastore.go | 78 ++++++++++++------- pkg/epp/datastore/datastore_test.go | 20 ++++- pkg/epp/util/pod/pod.go | 3 + 7 files changed, 88 insertions(+), 58 deletions(-) diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index 57dc2469b..80c30e191 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -25,6 +25,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -178,6 +179,7 @@ func TestInferenceModelReconciler(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Create a fake client with no InferenceModel objects. scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) _ = v1alpha2.Install(scheme) initObjs := []client.Object{} if test.model != nil { @@ -186,6 +188,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInAPIServer { initObjs = append(initObjs, m) } + fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(initObjs...). @@ -196,7 +199,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInStore { ds.ModelSetIfOlder(m) } - ds.PoolSet(pool) + _ = ds.PoolSet(context.Background(), fakeClient, pool) reconciler := &InferenceModelReconciler{ Client: fakeClient, Record: record.NewFakeRecorder(10), diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go index 0738181f0..fb7d77273 100644 --- a/pkg/epp/controller/inferencepool_reconciler.go +++ b/pkg/epp/controller/inferencepool_reconciler.go @@ -18,7 +18,6 @@ package controller import ( "context" - "reflect" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/tools/record" @@ -60,28 +59,15 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques c.Datastore.Clear() return ctrl.Result{}, nil } - - c.updateDatastore(ctx, infPool) + // update pool in datastore + if err := c.Datastore.PoolSet(ctx, c.Client, infPool); err != nil { + logger.Error(err, "Failed to update datastore") + return ctrl.Result{}, err + } return ctrl.Result{}, nil } -func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *v1alpha2.InferencePool) { - logger := log.FromContext(ctx) - oldPool, err := c.Datastore.PoolGet() - c.Datastore.PoolSet(newPool) - if err != nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { - logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", newPool.Spec.Selector) - // A full resync is required to address two cases: - // 1) At startup, the pod events may get processed before the pool is synced with the datastore, - // and hence they will not be added to the store since pool selector is not known yet - // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need - // to resync the whole pool: remove pods in the store that don't match the new selector and add - // the ones that may have existed already to the store. 
- c.Datastore.PodResyncAll(ctx, c.Client, newPool) - } -} - func (c *InferencePoolReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). For(&v1alpha2.InferencePool{}). diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 494adeb79..6d1af8d9a 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -27,7 +27,6 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" @@ -41,8 +40,7 @@ type PodReconciler struct { func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - pool, err := c.Datastore.PoolGet() - if err != nil { + if !c.Datastore.PoolHasSynced() { logger.V(logutil.TRACE).Info("Skipping reconciling Pod because the InferencePool is not available yet") // When the inferencePool is initialized it lists the appropriate pods and populates the datastore, so no need to requeue. return ctrl.Result{}, nil @@ -60,7 +58,7 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R return ctrl.Result{}, err } - c.updateDatastore(logger, pod, pool) + c.updateDatastore(logger, pod) return ctrl.Result{}, nil } @@ -70,13 +68,13 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(c) } -func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod, pool *v1alpha2.InferencePool) { +func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podutil.IsPodReady(pod) { + if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) { logger.V(logutil.DEBUG).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { - if c.Datastore.PodUpdateOrAddIfNotExist(pod, pool) { + if c.Datastore.PodUpdateOrAddIfNotExist(pod) { logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) } else { logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index e4cb0b62d..d2bdd5d09 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -182,9 +182,9 @@ func TestPodReconciler(t *testing.T) { // Configure the initial state of the datastore. 
store := datastore.NewDatastore(t.Context(), pmf) - store.PoolSet(test.pool) + _ = store.PoolSet(t.Context(), fakeClient, test.pool) for _, pod := range test.existingPods { - store.PodUpdateOrAddIfNotExist(pod, pool) + store.PodUpdateOrAddIfNotExist(pod) } podReconciler := &PodReconciler{Client: fakeClient, Datastore: store} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 5435e3af8..f8378d250 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "reflect" "sync" corev1 "k8s.io/api/core/v1" @@ -44,7 +45,10 @@ var ( // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) type Datastore interface { // InferencePool operations - PoolSet(pool *v1alpha2.InferencePool) + // PoolSet sets the given pool in datastore. If the given pool has different label selector than the previous pool + // that was stored, the function triggers a resync of the pods to keep the datastore updated. If the given pool + // is nil, this call triggers the datastore.Clear() function. + PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error PoolGet() (*v1alpha2.InferencePool, error) PoolHasSynced() bool PoolLabelsMatch(podLabels map[string]string) bool @@ -60,10 +64,9 @@ type Datastore interface { // PodGetAll returns all pods and metrics, including fresh and stale. PodGetAll() []backendmetrics.PodMetrics // PodList lists pods matching the given predicate. - PodList(func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics - PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool + PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool PodDelete(namespacedName types.NamespacedName) - PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) // Clears the store state, happens when the pool gets deleted. Clear() @@ -102,10 +105,31 @@ func (ds *datastore) Clear() { } // /// InferencePool APIs /// -func (ds *datastore) PoolSet(pool *v1alpha2.InferencePool) { +func (ds *datastore) PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error { + if pool == nil { + ds.Clear() + return nil + } + logger := log.FromContext(ctx) ds.poolAndModelsMu.Lock() defer ds.poolAndModelsMu.Unlock() + + oldPool := ds.pool ds.pool = pool + if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) { + logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector) + // A full resync is required to address two cases: + // 1) At startup, the pod events may get processed before the pool is synced with the datastore, + // and hence they will not be added to the store since pool selector is not known yet + // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need + // to resync the whole pool: remove pods in the store that don't match the new selector and add + // the ones that may have existed already to the store. 
+ if err := ds.podResyncAll(ctx, client); err != nil { + return fmt.Errorf("failed to update pods according to the pool selector - %w", err) + } + } + + return nil } func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { @@ -229,7 +253,7 @@ func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []b return res } -func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool { +func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { namespacedName := types.NamespacedName{ Name: pod.Name, Namespace: pod.Namespace, @@ -247,27 +271,35 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.In return ok } -func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) { +func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { + v, ok := ds.pods.LoadAndDelete(namespacedName) + if ok { + pmr := v.(backendmetrics.PodMetrics) + pmr.StopRefreshLoop() + } +} + +func (ds *datastore) podResyncAll(ctx context.Context, ctrlClient client.Client) error { logger := log.FromContext(ctx) podList := &corev1.PodList{} if err := ctrlClient.List(ctx, podList, &client.ListOptions{ - LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector), - Namespace: pool.Namespace, + LabelSelector: selectorFromInferencePoolSelector(ds.pool.Spec.Selector), + Namespace: ds.pool.Namespace, }); err != nil { - log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients") - return + return fmt.Errorf("failed to list pods - %w", err) } activePods := make(map[string]bool) for _, pod := range podList.Items { - if podutil.IsPodReady(&pod) { - namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - activePods[pod.Name] = true - if ds.PodUpdateOrAddIfNotExist(&pod, pool) { - logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) - } else { - logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) - } + if !podutil.IsPodReady(&pod) { + continue + } + namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + activePods[pod.Name] = true + if ds.PodUpdateOrAddIfNotExist(&pod) { + logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) + } else { + logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) } } @@ -281,14 +313,8 @@ func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, return true } ds.pods.Range(deleteFn) -} -func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { - v, ok := ds.pods.LoadAndDelete(namespacedName) - if ok { - pmr := v.(backendmetrics.PodMetrics) - pmr.StopRefreshLoop() - } + return nil } func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector { diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 9e5d5821b..762aec2c0 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -29,8 +29,10 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/config" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" @@ -74,9 +76,15 @@ func TestPool(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) datastore := NewDatastore(context.Background(), pmf) - datastore.PoolSet(tt.inferencePool) + _ = datastore.PoolSet(context.Background(), fakeClient, tt.inferencePool) gotPool, gotErr := datastore.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { t.Errorf("Unexpected error diff (+got/-want): %s", diff) @@ -323,11 +331,17 @@ func TestMetrics(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) ds := NewDatastore(ctx, pmf) - ds.PoolSet(inferencePool) + _ = ds.PoolSet(ctx, fakeClient, inferencePool) for _, pod := range test.storePods { - ds.PodUpdateOrAddIfNotExist(pod, inferencePool) + ds.PodUpdateOrAddIfNotExist(pod) } assert.EventuallyWithT(t, func(t *assert.CollectT) { got := ds.PodGetAll() diff --git a/pkg/epp/util/pod/pod.go b/pkg/epp/util/pod/pod.go index 9f5640245..4fcb948fc 100644 --- a/pkg/epp/util/pod/pod.go +++ b/pkg/epp/util/pod/pod.go @@ -21,6 +21,9 @@ import ( ) func IsPodReady(pod *corev1.Pod) bool { + if !pod.DeletionTimestamp.IsZero() { + return false + } for _, condition := range pod.Status.Conditions { if condition.Type == corev1.PodReady { if condition.Status == corev1.ConditionTrue { From d167e49a2aaba8f40e550f0bb41b4760f74f77dd Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Wed, 23 Apr 2025 20:15:47 +0300 Subject: [PATCH 12/20] scheduler refactoring (#730) Signed-off-by: Nir Rozenbaum --- pkg/epp/backend/metrics/pod_metrics.go | 11 +- pkg/epp/backend/metrics/types.go | 15 +- .../scheduling/plugins/{ => filter}/filter.go | 85 +++++------ .../plugins/{ => filter}/filter_test.go | 38 ++--- pkg/epp/scheduling/plugins/noop.go | 12 +- .../{picker.go => picker/random_picker.go} | 6 +- .../interfaces.go => plugins/plugins.go} | 42 +++--- pkg/epp/scheduling/scheduler.go | 141 ++++++++---------- pkg/epp/scheduling/scheduler_test.go | 133 ++++++----------- pkg/epp/scheduling/types/types.go | 16 +- 10 files changed, 214 insertions(+), 285 deletions(-) rename pkg/epp/scheduling/plugins/{ => filter}/filter.go (81%) rename pkg/epp/scheduling/plugins/{ => filter}/filter_test.go (90%) rename pkg/epp/scheduling/plugins/{picker.go => picker/random_picker.go} (86%) rename pkg/epp/scheduling/{types/interfaces.go => plugins/plugins.go} (70%) diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index c85d4d794..7339389ad 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -41,9 +41,8 @@ type podMetrics struct { ds Datastore interval time.Duration - parentCtx context.Context - once sync.Once // ensure the StartRefreshLoop is only called once. - done chan struct{} + once sync.Once // ensure the StartRefreshLoop is only called once. 
+ done chan struct{} logger logr.Logger } @@ -79,8 +78,8 @@ func toInternalPod(in *corev1.Pod) *Pod { } // start starts a goroutine exactly once to periodically update metrics. The goroutine will be -// stopped either when stop() is called, or the parentCtx is cancelled. -func (pm *podMetrics) startRefreshLoop() { +// stopped either when stop() is called, or the given ctx is cancelled. +func (pm *podMetrics) startRefreshLoop(ctx context.Context) { pm.once.Do(func() { go func() { pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod()) @@ -90,7 +89,7 @@ func (pm *podMetrics) startRefreshLoop() { select { case <-pm.done: return - case <-pm.parentCtx.Done(): + case <-ctx.Done(): return case <-ticker.C: // refresh metrics periodically if err := pm.refreshMetrics(); err != nil { diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 21c0f4016..156ac3ed6 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -43,18 +43,17 @@ type PodMetricsFactory struct { func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics { pod := toInternalPod(in) pm := &podMetrics{ - pmc: f.pmc, - ds: ds, - interval: f.refreshMetricsInterval, - parentCtx: parentCtx, - once: sync.Once{}, - done: make(chan struct{}), - logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), + pmc: f.pmc, + ds: ds, + interval: f.refreshMetricsInterval, + once: sync.Once{}, + done: make(chan struct{}), + logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), } pm.pod.Store(pod) pm.metrics.Store(newMetrics()) - pm.startRefreshLoop() + pm.startRefreshLoop(parentCtx) return pm } diff --git a/pkg/epp/scheduling/plugins/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go similarity index 81% rename from pkg/epp/scheduling/plugins/filter.go rename to pkg/epp/scheduling/plugins/filter/filter.go index efcb6be17..86620aa9f 100644 --- a/pkg/epp/scheduling/plugins/filter.go +++ b/pkg/epp/scheduling/plugins/filter/filter.go @@ -14,56 +14,55 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plugins +package filter import ( - "errors" "math" "math/rand" "time" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -type Filter struct { +type baseFilter struct { name string filter filterFunc } -func (bf *Filter) Name() string { - if bf == nil { +func (f *baseFilter) Name() string { + if f == nil { return "nil" } - return bf.name + return f.name } -func (bf *Filter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func (f *baseFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { loggerTrace := ctx.Logger.V(logutil.TRACE) - loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods)) + loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods)) - return bf.filter(ctx, pods) + return f.filter(ctx, pods) } // DecisionTreeFilter applies current filterFunc, and then recursively applies next filters // depending success or failure of the current filter. // It can be used to construct a flow chart algorithm. 
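// As an illustration only (a minimal sketch, not part of this change), a tree that
// prefers low-queue pods and then applies LoRA affinity, falling back to least-queue
// selection, could be composed from the filters defined in this package:
//
//	tree := &DecisionTreeFilter{
//		Current: LowQueueFilter,
//		NextOnSuccess: &DecisionTreeFilter{
//			Current:                LoRAAffinityFilter,
//			NextOnSuccessOrFailure: &DecisionTreeFilter{Current: LeastKVCacheFilter},
//		},
//		NextOnFailure: &DecisionTreeFilter{Current: LeastQueueFilter},
//	}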
type DecisionTreeFilter struct { - Current types.Filter + Current plugins.Filter // NextOnSuccess filter will be applied after successfully applying the current filter. // The filtered results will be passed to the next filter. - NextOnSuccess types.Filter + NextOnSuccess plugins.Filter // NextOnFailure filter will be applied if current filter fails. // The original input will be passed to the next filter. - NextOnFailure types.Filter + NextOnFailure plugins.Filter // NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the // success or failure of the current filter. // NOTE: When using NextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil. // However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of // NextOnSuccessOrFailure, in the success and failure scenarios, respectively. - NextOnSuccessOrFailure types.Filter + NextOnSuccessOrFailure plugins.Filter } func (f *DecisionTreeFilter) Name() string { @@ -73,15 +72,15 @@ func (f *DecisionTreeFilter) Name() string { return f.Current.Name() } -func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { loggerTrace := ctx.Logger.V(logutil.TRACE) - filtered, err := f.Current.Filter(ctx, pods) + filtered := f.Current.Filter(ctx, pods) next := f.NextOnSuccessOrFailure - if err == nil && len(filtered) > 0 { + if len(filtered) > 0 { if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } if f.NextOnSuccess != nil { next = f.NextOnSuccess @@ -92,7 +91,7 @@ func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]typ } else { if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } if f.NextOnFailure != nil { next = f.NextOnFailure @@ -104,11 +103,11 @@ func (f *DecisionTreeFilter) Filter(ctx *types.Context, pods []types.Pod) ([]typ } // filterFunc filters a set of input pods to a subset. -type filterFunc func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) +type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { + return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { filtered := []types.Pod{} for _, pod := range pods { pass := pp(ctx.Req, pod) @@ -116,14 +115,12 @@ func toFilterFunc(pp podPredicate) filterFunc { filtered = append(filtered, pod) } } - if len(filtered) == 0 { - return nil, errors.New("no pods left") - } - return filtered, nil + + return filtered } } -var LeastQueueFilter = &Filter{ +var LeastQueueFilter = &baseFilter{ name: "least queuing", filter: leastQueuingFilterFunc, } @@ -135,7 +132,7 @@ var LeastQueueFilter = &Filter{ // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
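// For example, with pod queue sizes of 2, 5, and 11, the range is 9, the bucket width is
// 9/3 = 3, and the pods with queue sizes 2 and 5 (i.e. within [2, 5]) are kept.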
-func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxInt max := 0 filtered := []types.Pod{} @@ -154,15 +151,15 @@ func leastQueuingFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, filtered = append(filtered, pod) } } - return filtered, nil + return filtered } -var LowQueueFilter = &Filter{ +var LowQueueFilter = &baseFilter{ name: "low queueing filter", filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), } -var LeastKVCacheFilter = &Filter{ +var LeastKVCacheFilter = &baseFilter{ name: "least KV cache percent", filter: leastKVCacheFilterFunc, } @@ -173,7 +170,7 @@ var LeastKVCacheFilter = &Filter{ // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxFloat64 var max float64 = 0 filtered := []types.Pod{} @@ -192,10 +189,10 @@ func leastKVCacheFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, filtered = append(filtered, pod) } } - return filtered, nil + return filtered } -var LoRAAffinityFilter = &Filter{ +var LoRAAffinityFilter = &baseFilter{ name: "affinity LoRA", filter: loRASoftAffinityFilterFunc, } @@ -216,7 +213,7 @@ var LoRAAffinityFilter = &Filter{ // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { // Pre-allocate slices with estimated capacity filtered_affinity := make([]types.Pod, 0, len(pods)) @@ -241,34 +238,24 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []types.Pod) ([]types.P // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { if randGen.Float64() < config.Conf.LoraAffinityThreshold { - return filtered_affinity, nil + return filtered_affinity } - return filtered_available, nil + return filtered_available } // Return whichever group has pods if len(filtered_affinity) > 0 { - return filtered_affinity, nil + return filtered_affinity } - return filtered_available, nil + return filtered_available } -var HasCapacityFilter = &Filter{ +var HasCapacityFilter = &baseFilter{ name: "has capacity for sheddable requests", filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), } -var DropRequestFilter = &Filter{ - name: "drop request", - filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { - ctx.Logger.V(logutil.DEFAULT).Info("Request dropped", "request", ctx.Req) - return []types.Pod{}, errutil.Error{ - Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources", - } - }, -} - // podPredicate is a filter function to check whether a pod is desired. 
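// A hypothetical example (the threshold values are illustrative, not taken from config):
// two predicates can be chained with and() and wrapped into a filter via toFilterFunc:
//
//	var lowLoadFilter = &baseFilter{
//		name:   "low load",
//		filter: toFilterFunc(queueThresholdPredicate(5).and(kvCacheThresholdPredicate(0.8))),
//	}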
type podPredicate func(req *types.LLMRequest, pod types.Pod) bool diff --git a/pkg/epp/scheduling/plugins/filter_test.go b/pkg/epp/scheduling/plugins/filter/filter_test.go similarity index 90% rename from pkg/epp/scheduling/plugins/filter_test.go rename to pkg/epp/scheduling/plugins/filter/filter_test.go index 107b423fb..56cccb3b8 100644 --- a/pkg/epp/scheduling/plugins/filter_test.go +++ b/pkg/epp/scheduling/plugins/filter/filter_test.go @@ -14,11 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plugins +package filter import ( "context" - "errors" "testing" "github.com/google/go-cmp/cmp" @@ -34,30 +33,26 @@ func TestFilter(t *testing.T) { req *types.LLMRequest input []types.Pod output []types.Pod - err bool filter *DecisionTreeFilter }{ { - name: "simple filter without successor, failure", + name: "simple filter without available pods", filter: &DecisionTreeFilter{ - Current: &Filter{ - name: "error", - filter: func(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { - return nil, errors.New("filter error") + Current: &baseFilter{ + name: "filter all", + filter: func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + return []types.Pod{} }, }, }, - err: true, + output: []types.Pod{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewContext(context.Background(), test.req, test.input) - got, err := test.filter.Filter(ctx, test.input) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.filter.Filter(ctx, test.input) opt := cmp.AllowUnexported(types.PodMetrics{}) if diff := cmp.Diff(test.output, got, opt); diff != "" { @@ -74,7 +69,6 @@ func TestFilterFunc(t *testing.T) { req *types.LLMRequest input []types.Pod output []types.Pod - err bool }{ { name: "least queuing empty input", @@ -193,11 +187,8 @@ func TestFilterFunc(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := types.NewContext(context.Background(), test.req, test.input) - got, err := test.f(ctx, test.input) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.f(ctx, test.input) opt := cmp.AllowUnexported(types.PodMetrics{}) if diff := cmp.Diff(test.output, got, opt); diff != "" { @@ -254,7 +245,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { }, }, } - ctx := types.NewContext(context.Background(), req, pods) + ctx := types.NewSchedulingContext(context.Background(), req, pods) // Run the filter function multiple times and count the results affinityCount := 0 @@ -265,10 +256,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { expectedAvailabilityPercent := 100 - expectedAffinityPercent for i := 0; i < numIterations; i++ { - result, err := loRASoftAffinityFilterFunc(ctx, pods) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + result := loRASoftAffinityFilterFunc(ctx, pods) // Check which type of pod was returned if len(result) != 1 { diff --git a/pkg/epp/scheduling/plugins/noop.go b/pkg/epp/scheduling/plugins/noop.go index 1abcb95b1..8f50ff36e 100644 --- a/pkg/epp/scheduling/plugins/noop.go +++ b/pkg/epp/scheduling/plugins/noop.go @@ -27,12 +27,16 @@ type NoopPlugin struct{} func (p *NoopPlugin) Name() string { return "NoopPlugin" } -func (p 
*NoopPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { return 0.0, nil } +func (p *NoopPlugin) PreSchedule(ctx *types.SchedulingContext) {} -func (p *NoopPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func (p *NoopPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) ([]types.Pod, error) { return pods, nil } -func (p *NoopPlugin) PreSchedule(ctx *types.Context) {} +func (p *NoopPlugin) Score(ctx *types.SchedulingContext, pod types.Pod) (float64, error) { + return 0.0, nil +} + +func (p *NoopPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {} -func (p *NoopPlugin) PostSchedule(ctx *types.Context, res *types.Result) {} +func (p *NoopPlugin) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {} diff --git a/pkg/epp/scheduling/plugins/picker.go b/pkg/epp/scheduling/plugins/picker/random_picker.go similarity index 86% rename from pkg/epp/scheduling/plugins/picker.go rename to pkg/epp/scheduling/plugins/picker/random_picker.go index 569e4e86a..850108e7e 100644 --- a/pkg/epp/scheduling/plugins/picker.go +++ b/pkg/epp/scheduling/plugins/picker/random_picker.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package plugins +package picker import ( "fmt" @@ -30,8 +30,8 @@ func (rp *RandomPicker) Name() string { return "random" } -func (rp *RandomPicker) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { +func (rp *RandomPicker) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result { ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) i := rand.Intn(len(pods)) - return &types.Result{TargetPod: pods[i]}, nil + return &types.Result{TargetPod: pods[i]} } diff --git a/pkg/epp/scheduling/types/interfaces.go b/pkg/epp/scheduling/plugins/plugins.go similarity index 70% rename from pkg/epp/scheduling/types/interfaces.go rename to pkg/epp/scheduling/plugins/plugins.go index 6e954cef0..4b334803b 100644 --- a/pkg/epp/scheduling/types/interfaces.go +++ b/pkg/epp/scheduling/plugins/plugins.go @@ -14,28 +14,21 @@ See the License for the specific language governing permissions and limitations under the License. */ -package types +package plugins import ( - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) const ( PreSchedulerPluginType = "PreSchedule" - PostSchedulePluginType = "PostSchedule" FilterPluginType = "Filter" ScorerPluginType = "Scorer" + PostSchedulePluginType = "PostSchedule" PickerPluginType = "Picker" + PostResponsePluginType = "PostResponse" ) -type Pod interface { - GetPod() *backendmetrics.Pod - GetMetrics() *backendmetrics.Metrics - SetScore(float64) - Score() float64 - String() string -} - // Plugin defines the interface for scheduler plugins, combining scoring, filtering, // and event handling capabilities. type Plugin interface { @@ -47,29 +40,36 @@ type Plugin interface { // initialization work. type PreSchedule interface { Plugin - PreSchedule(ctx *Context) -} - -// PostSchedule is called by the scheduler after it selects a targetPod for the request. -type PostSchedule interface { - Plugin - PostSchedule(ctx *Context, res *Result) + PreSchedule(ctx *types.SchedulingContext) } // Filter defines the interface for filtering a list of pods based on context. 
type Filter interface { Plugin - Filter(ctx *Context, pods []Pod) ([]Pod, error) + Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod } // Scorer defines the interface for scoring pods based on context. type Scorer interface { Plugin - Score(ctx *Context, pod Pod) (float64, error) + Score(ctx *types.SchedulingContext, pod types.Pod) float64 +} + +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { + Plugin + PostSchedule(ctx *types.SchedulingContext, res *types.Result) } // Picker picks the final pod(s) to send the request to. type Picker interface { Plugin - Pick(ctx *Context, pods []Pod) (*Result, error) + Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result +} + +// PostResponse is called by the scheduler after a successful response was sent. +// The given pod argument is the pod that served the request. +type PostResponse interface { + Plugin + PostResponse(ctx *types.SchedulingContext, pod types.Pod) } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 7cc2bd968..beac5e6b8 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -26,42 +26,44 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) var ( - lowLatencyFilter = &plugins.DecisionTreeFilter{ - Current: plugins.LowQueueFilter, - NextOnSuccess: &plugins.DecisionTreeFilter{ - Current: plugins.LoRAAffinityFilter, - NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ - Current: plugins.LeastQueueFilter, - NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ - Current: plugins.LeastKVCacheFilter, + lowLatencyFilter = &filter.DecisionTreeFilter{ + Current: filter.LowQueueFilter, + NextOnSuccess: &filter.DecisionTreeFilter{ + Current: filter.LoRAAffinityFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastQueueFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastKVCacheFilter, }, }, }, - NextOnFailure: &plugins.DecisionTreeFilter{ - Current: plugins.LeastQueueFilter, - NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ - Current: plugins.LoRAAffinityFilter, - NextOnSuccessOrFailure: &plugins.DecisionTreeFilter{ - Current: plugins.LeastKVCacheFilter, + NextOnFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastQueueFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LoRAAffinityFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastKVCacheFilter, }, }, }, } - sheddableRequestFilter = &plugins.DecisionTreeFilter{ + sheddableRequestFilter = &filter.DecisionTreeFilter{ // When there is at least one model server that's not queuing requests, and still has KV // cache below a certain threshold, we consider this model server has capacity to handle // a sheddable request without impacting critical requests. 
- Current: plugins.HasCapacityFilter, + Current: filter.HasCapacityFilter, NextOnSuccess: lowLatencyFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable - // request to make room for critical requests. - NextOnFailure: plugins.DropRequestFilter, + // request to make room for critical requests. for this, we don't define nextOnFailure. } ) @@ -70,21 +72,21 @@ func NewScheduler(datastore Datastore) *Scheduler { return &Scheduler{ datastore: datastore, - preSchedulePlugins: []types.PreSchedule{}, - postSchedulePlugins: []types.PostSchedule{}, - scorers: []types.Scorer{}, - filters: []types.Filter{defaultPlugin}, + preSchedulePlugins: []plugins.PreSchedule{}, + scorers: []plugins.Scorer{}, + filters: []plugins.Filter{defaultPlugin}, + postSchedulePlugins: []plugins.PostSchedule{}, picker: defaultPlugin, } } type Scheduler struct { datastore Datastore - preSchedulePlugins []types.PreSchedule - postSchedulePlugins []types.PostSchedule - filters []types.Filter - scorers []types.Scorer - picker types.Picker + preSchedulePlugins []plugins.PreSchedule + filters []plugins.Filter + scorers []plugins.Scorer + postSchedulePlugins []plugins.PostSchedule + picker plugins.Picker } type Datastore interface { @@ -99,26 +101,21 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) loggerDebug.Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) - pods, err := s.runFilterPlugins(sCtx) - if err != nil { - return nil, err + pods := s.runFilterPlugins(sCtx) + if len(pods) == 0 { + return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "failed to find a target pod"} } - if err := s.runScorerPlugins(sCtx, pods); err != nil { - return nil, err - } + s.runScorerPlugins(sCtx, pods) before := time.Now() - res, err := s.picker.Pick(sCtx, pods) - metrics.RecordSchedulerPluginProcessingLatency(types.PickerPluginType, s.picker.Name(), time.Since(before)) - if err != nil { - return nil, err - } + res := s.picker.Pick(sCtx, pods) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before)) loggerDebug.Info("After running picker plugins", "result", res) s.runPostSchedulePlugins(sCtx, res) @@ -126,91 +123,79 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types return res, nil } -func (s *Scheduler) runPreSchedulePlugins(ctx *types.Context) { +func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) { for _, plugin := range s.preSchedulePlugins { ctx.Logger.V(logutil.DEBUG).Info("Running pre-schedule plugin", "plugin", plugin.Name()) before := time.Now() plugin.PreSchedule(ctx) - metrics.RecordSchedulerPluginProcessingLatency(types.PreSchedulerPluginType, plugin.Name(), time.Since(before)) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PreSchedulerPluginType, plugin.Name(), time.Since(before)) } } -func (s *Scheduler) runPostSchedulePlugins(ctx *types.Context, res *types.Result) { +func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) { for _, plugin := range s.postSchedulePlugins { 
ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name()) before := time.Now() plugin.PostSchedule(ctx, res) - metrics.RecordSchedulerPluginProcessingLatency(types.PostSchedulePluginType, plugin.Name(), time.Since(before)) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before)) } } -func (s *Scheduler) runFilterPlugins(ctx *types.Context) ([]types.Pod, error) { +func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod { loggerDebug := ctx.Logger.V(logutil.DEBUG) - pods := ctx.PodsSnapshot - loggerDebug.Info("Before running filter plugins", "pods", pods) + filteredPods := ctx.PodsSnapshot + loggerDebug.Info("Before running filter plugins", "pods", filteredPods) + for _, filter := range s.filters { loggerDebug.Info("Running filter plugin", "plugin", filter.Name()) before := time.Now() - filteredPods, err := filter.Filter(ctx, pods) - metrics.RecordSchedulerPluginProcessingLatency(types.FilterPluginType, filter.Name(), time.Since(before)) - if err != nil || len(filteredPods) == 0 { - return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(filteredPods), err) + filteredPods = filter.Filter(ctx, filteredPods) + metrics.RecordSchedulerPluginProcessingLatency(plugins.FilterPluginType, filter.Name(), time.Since(before)) + loggerDebug.Info("Filter plugin result", "plugin", filter.Name(), "pods", filteredPods) + if len(filteredPods) == 0 { + break } - pods = filteredPods - loggerDebug.Info("Filter plugin result", "plugin", filter.Name(), "pods", pods) } - loggerDebug.Info("After running filter plugins", "pods", pods) - return pods, nil + return filteredPods } -func (s *Scheduler) runScorerPlugins(ctx *types.Context, pods []types.Pod) error { +func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) { loggerDebug := ctx.Logger.V(logutil.DEBUG) loggerDebug.Info("Before running score plugins", "pods", pods) for _, pod := range pods { - score, err := runScorersForPod(ctx, s.scorers, pod) - if err != nil { - return err - } + score := s.runScorersForPod(ctx, pod) pod.SetScore(score) } loggerDebug.Info("After running score plugins", "pods", pods) - return nil } // Iterate through each scorer in the chain and accumulate the scores. 
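// For example (illustrative numbers): if one scorer returns 0.4 and another returns 0.7
// for a pod, the pod's accumulated score is 1.1.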
-func runScorersForPod(ctx *types.Context, scorers []types.Scorer, pod types.Pod) (float64, error) { +func (s *Scheduler) runScorersForPod(ctx *types.SchedulingContext, pod types.Pod) float64 { logger := ctx.Logger.WithValues("pod", pod.GetPod().NamespacedName).V(logutil.DEBUG) score := float64(0) - for _, scorer := range scorers { + for _, scorer := range s.scorers { logger.Info("Running scorer", "scorer", scorer.Name()) before := time.Now() - oneScore, err := scorer.Score(ctx, pod) - metrics.RecordSchedulerPluginProcessingLatency(types.ScorerPluginType, scorer.Name(), time.Since(before)) - if err != nil { - logger.Error(err, "Failed to calculate score for scorer", "scorer", scorer.Name()) - return 0, err - } + oneScore := scorer.Score(ctx, pod) + metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before)) score += oneScore logger.Info("After scorer", "scorer", scorer.Name(), "score", oneScore, "total score", score) } - return score, nil + return score } type defaultPlugin struct { - plugins.RandomPicker + picker.RandomPicker } func (p *defaultPlugin) Name() string { return "DefaultPlugin" } -func (p *defaultPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { - req := ctx.Req - var filter types.Filter - if req.Critical { - filter = lowLatencyFilter - } else { - filter = sheddableRequestFilter +func (p *defaultPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + if ctx.Req.Critical { + return lowLatencyFilter.Filter(ctx, pods) } - return filter.Filter(ctx, pods) + + return sheddableRequestFilter.Filter(ctx, pods) } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index 5a2265bff..cb729038e 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -18,12 +18,12 @@ package scheduling import ( "context" - "errors" "testing" "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) @@ -247,30 +247,22 @@ func TestSchedulePlugins(t *testing.T) { ScoreRes: 0.8, FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, } - tpFilterErr := &TestPlugin{ - NameRes: "filter err", - FilterErr: errors.New("filter error"), - } - tpScorerErr := &TestPlugin{ - NameRes: "score err", - ScoreErr: errors.New("score err"), + tp_filterAll := &TestPlugin{ + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, } pickerPlugin := &TestPlugin{ NameRes: "picker", PickRes: k8stypes.NamespacedName{Name: "pod1"}, } - pickerErr := &TestPlugin{ - NameRes: "picker err", - PickErr: errors.New("picker err"), - } tests := []struct { name string - preSchedulePlugins []types.PreSchedule - postSchedulePlugins []types.PostSchedule - filters []types.Filter - scorers []types.Scorer - picker types.Picker + preSchedulePlugins []plugins.PreSchedule + filters []plugins.Filter + scorers []plugins.Scorer + postSchedulePlugins []plugins.PostSchedule + picker plugins.Picker input []*backendmetrics.FakePodMetrics wantTargetPod k8stypes.NamespacedName targetPodScore float64 @@ -280,10 +272,10 @@ func TestSchedulePlugins(t *testing.T) { }{ { name: "all plugins executed successfully", - preSchedulePlugins: []types.PreSchedule{tp1, tp2}, - postSchedulePlugins: 
[]types.PostSchedule{tp1, tp2}, - filters: []types.Filter{tp1, tp2}, - scorers: []types.Scorer{tp1, tp2}, + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: []plugins.Scorer{tp1, tp2}, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, picker: pickerPlugin, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, @@ -296,46 +288,19 @@ func TestSchedulePlugins(t *testing.T) { err: false, }, { - name: "filter error", - preSchedulePlugins: []types.PreSchedule{tp1, tp2}, - postSchedulePlugins: []types.PostSchedule{tp1, tp2}, - filters: []types.Filter{tp1, tpFilterErr}, - scorers: []types.Scorer{tp1, tp2}, - picker: pickerPlugin, - input: []*backendmetrics.FakePodMetrics{ - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, - }, - err: true, - }, - { - name: "scorer error", - preSchedulePlugins: []types.PreSchedule{tp1, tp2}, - postSchedulePlugins: []types.PostSchedule{tp1, tp2}, - filters: []types.Filter{tp1, tp2}, - scorers: []types.Scorer{tp1, tpScorerErr}, + name: "filter all", + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp_filterAll}, + scorers: []plugins.Scorer{tp1, tp2}, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, picker: pickerPlugin, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, }, - err: true, - }, - { - name: "picker error", - preSchedulePlugins: []types.PreSchedule{tp1, tp2}, - postSchedulePlugins: []types.PostSchedule{tp1, tp2}, - filters: []types.Filter{tp1, tp2}, - scorers: []types.Scorer{tp1, tp2}, - picker: pickerErr, - input: []*backendmetrics.FakePodMetrics{ - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, - {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, - }, - err: true, + numPodsToScore: 0, + err: true, // no available pods to server after filter all }, } @@ -343,26 +308,26 @@ func TestSchedulePlugins(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Reset all plugins before each new test case. 
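The new "filter all" case above exercises the revised contract directly: a filter signals "nothing schedulable" by returning an empty slice, the chain stops, and Schedule surfaces an error because the picker has nothing to choose from. A user-supplied filter can therefore be as small as the sketch below; the plugins.Filter shape and the types.Pod GetMetrics accessor appear in this diff, while the package name and the keep-only-pods-with-metrics rule are illustrative assumptions.

// A hypothetical filter written against the new plugins.Filter contract: it
// returns the subset of pods it keeps and never an error.
package example

import (
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// HasMetricsFilter keeps only pods for which a metrics snapshot is available.
type HasMetricsFilter struct{}

var _ plugins.Filter = &HasMetricsFilter{} // compile-time interface check

func (f *HasMetricsFilter) Name() string { return "has-metrics-filter" }

func (f *HasMetricsFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
	kept := []types.Pod{}
	for _, pod := range pods {
		if pod.GetMetrics() != nil {
			kept = append(kept, pod)
		}
	}
	// An empty result is legal: runFilterPlugins breaks out of the chain and
	// the scheduler reports that no pods are available.
	return kept
}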
for _, plugin := range test.preSchedulePlugins { - plugin.(*TestPlugin).Reset() + plugin.(*TestPlugin).reset() } for _, plugin := range test.postSchedulePlugins { - plugin.(*TestPlugin).Reset() + plugin.(*TestPlugin).reset() } for _, plugin := range test.filters { - plugin.(*TestPlugin).Reset() + plugin.(*TestPlugin).reset() } for _, plugin := range test.scorers { - plugin.(*TestPlugin).Reset() + plugin.(*TestPlugin).reset() } - test.picker.(*TestPlugin).Reset() + test.picker.(*TestPlugin).reset() // Initialize the scheduler scheduler := &Scheduler{ datastore: &fakeDataStore{pods: test.input}, preSchedulePlugins: test.preSchedulePlugins, - postSchedulePlugins: test.postSchedulePlugins, filters: test.filters, scorers: test.scorers, + postSchedulePlugins: test.postSchedulePlugins, picker: test.picker, } @@ -397,13 +362,6 @@ func TestSchedulePlugins(t *testing.T) { } } - for _, plugin := range test.postSchedulePlugins { - tp, _ := plugin.(*TestPlugin) - if tp.PostScheduleCallCount != 1 { - t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) - } - } - for _, plugin := range test.filters { tp, _ := plugin.(*TestPlugin) if tp.FilterCallCount != 1 { @@ -418,6 +376,13 @@ func TestSchedulePlugins(t *testing.T) { } } + for _, plugin := range test.postSchedulePlugins { + tp, _ := plugin.(*TestPlugin) + if tp.PostScheduleCallCount != 1 { + t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) + } + } + tp, _ := test.picker.(*TestPlugin) if tp.PickCallCount != 1 { t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.NameRes, tp.PickCallCount) @@ -444,55 +409,49 @@ type TestPlugin struct { NameRes string ScoreCallCount int ScoreRes float64 - ScoreErr error FilterCallCount int FilterRes []k8stypes.NamespacedName - FilterErr error PreScheduleCallCount int PostScheduleCallCount int PickCallCount int PickRes k8stypes.NamespacedName - PickErr error } func (tp *TestPlugin) Name() string { return tp.NameRes } -func (tp *TestPlugin) Score(ctx *types.Context, pod types.Pod) (float64, error) { - tp.ScoreCallCount++ - return tp.ScoreRes, tp.ScoreErr +func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) { + tp.PreScheduleCallCount++ } -func (tp *TestPlugin) Filter(ctx *types.Context, pods []types.Pod) ([]types.Pod, error) { +func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { tp.FilterCallCount++ - return findPods(ctx, tp.FilterRes...), tp.FilterErr + return findPods(ctx, tp.FilterRes...) 
} -func (tp *TestPlugin) PreSchedule(ctx *types.Context) { - tp.PreScheduleCallCount++ +func (tp *TestPlugin) Score(ctx *types.SchedulingContext, pod types.Pod) float64 { + tp.ScoreCallCount++ + return tp.ScoreRes } -func (tp *TestPlugin) PostSchedule(ctx *types.Context, res *types.Result) { +func (tp *TestPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { tp.PostScheduleCallCount++ } -func (tp *TestPlugin) Pick(ctx *types.Context, pods []types.Pod) (*types.Result, error) { +func (tp *TestPlugin) Pick(ctx *types.SchedulingContext, pods []types.Pod) *types.Result { tp.PickCallCount++ - if tp.PickErr != nil { - return nil, tp.PickErr - } pod := findPods(ctx, tp.PickRes)[0] - return &types.Result{TargetPod: pod}, nil + return &types.Result{TargetPod: pod} } -func (tp *TestPlugin) Reset() { +func (tp *TestPlugin) reset() { tp.PreScheduleCallCount = 0 - tp.PostScheduleCallCount = 0 tp.FilterCallCount = 0 tp.ScoreCallCount = 0 + tp.PostScheduleCallCount = 0 tp.PickCallCount = 0 } -func findPods(ctx *types.Context, names ...k8stypes.NamespacedName) []types.Pod { +func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { res := []types.Pod{} for _, pod := range ctx.PodsSnapshot { for _, name := range names { diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index e52e90472..e66b5fb5d 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -40,8 +40,16 @@ func (r *LLMRequest) String() string { return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, len(r.Prompt)) } -// Context holds contextual information during a scheduling operation. -type Context struct { +type Pod interface { + GetPod() *backendmetrics.Pod + GetMetrics() *backendmetrics.Metrics + SetScore(float64) + Score() float64 + String() string +} + +// SchedulingContext holds contextual information during a scheduling operation. 
+type SchedulingContext struct { context.Context Logger logr.Logger Req *LLMRequest @@ -77,9 +85,9 @@ type PodMetrics struct { *backendmetrics.Metrics } -func NewContext(ctx context.Context, req *LLMRequest, pods []Pod) *Context { +func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) - return &Context{ + return &SchedulingContext{ Context: ctx, Logger: logger, Req: req, From c262a1d3950cb7c27c2f36088309825f722ad05d Mon Sep 17 00:00:00 2001 From: nayihz Date: Thu, 24 Apr 2025 01:41:46 +0800 Subject: [PATCH 13/20] filter irrelevant pod in pod_reconciler (#696) --- pkg/epp/controller/pod_reconciler.go | 22 ++++++++++++++++++++++ pkg/epp/datastore/datastore.go | 3 +++ 2 files changed, 25 insertions(+) diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 6d1af8d9a..5f1df10d7 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -26,7 +26,9 @@ import ( "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" @@ -63,8 +65,28 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R } func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { + filter := predicate.Funcs{ + CreateFunc: func(ce event.CreateEvent) bool { + pod := ce.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + UpdateFunc: func(ue event.UpdateEvent) bool { + oldPod := ue.ObjectOld.(*corev1.Pod) + newPod := ue.ObjectNew.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(oldPod.GetLabels()) || c.Datastore.PoolLabelsMatch(newPod.GetLabels()) + }, + DeleteFunc: func(de event.DeleteEvent) bool { + pod := de.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + GenericFunc: func(ge event.GenericEvent) bool { + pod := ge.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + } return ctrl.NewControllerManagedBy(mgr). For(&corev1.Pod{}). + WithEventFilter(filter). 
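Every branch of the predicate above reduces to one question: do the pod's labels satisfy the InferencePool selector currently held by the datastore? (The controller builder finishes with Complete(c) just below.) Stripped of the event plumbing, that check is an ordinary Kubernetes label-selector match, and the datastore hunk that follows additionally returns false while no pool has been observed yet. A self-contained illustration using only k8s.io/apimachinery, with the selector and label sets chosen purely for the example:

package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/labels"
)

func main() {
	// Selector equivalent to an InferencePool selecting app=vllm_v1 pods.
	poolSelector := labels.SelectorFromSet(labels.Set{"app": "vllm_v1"})

	// Label sets as they might arrive on pod events.
	matching := labels.Set{"app": "vllm_v1", "pod-template-hash": "abc123"}
	unrelated := labels.Set{"app": "some-other-workload"}

	fmt.Println(poolSelector.Matches(matching))  // true  -> the event reaches Reconcile
	fmt.Println(poolSelector.Matches(unrelated)) // false -> the event is filtered out
}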
Complete(c) } diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index f8378d250..22c500220 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -150,6 +150,9 @@ func (ds *datastore) PoolHasSynced() bool { func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { ds.poolAndModelsMu.RLock() defer ds.poolAndModelsMu.RUnlock() + if ds.pool == nil { + return false + } poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector) podSet := labels.Set(podLabels) return poolSelector.Matches(podSet) From 4707ab2954a4dbbe24e070e01442ec3cdae2a81d Mon Sep 17 00:00:00 2001 From: Daneyon Hansen Date: Wed, 23 Apr 2025 14:30:31 -0700 Subject: [PATCH 14/20] EPP: Update GetRandomPod() to return nil if no pods exist (#731) Signed-off-by: Daneyon Hansen --- pkg/epp/handlers/request.go | 3 ++ pkg/epp/handlers/server.go | 3 ++ pkg/epp/handlers/streamingserver_test.go | 55 ++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 9121b59af..8d30e543d 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -138,6 +138,9 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ // The above PR will address endpoint admission, but currently any request without a body will be // routed to a random upstream pod. pod := GetRandomPod(s.datastore) + if pod == nil { + return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"} + } pool, err := s.datastore.PoolGet() if err != nil { return err diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 2e3a35fe7..5e23c7a0a 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -449,6 +449,9 @@ func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed func GetRandomPod(ds datastore.Datastore) *backendmetrics.Pod { pods := ds.PodGetAll() + if len(pods) == 0 { + return nil + } number := rand.Intn(len(pods)) pod := pods[number] return pod.GetPod() diff --git a/pkg/epp/handlers/streamingserver_test.go b/pkg/epp/handlers/streamingserver_test.go index 72f7031a4..23d2b68fa 100644 --- a/pkg/epp/handlers/streamingserver_test.go +++ b/pkg/epp/handlers/streamingserver_test.go @@ -18,8 +18,14 @@ package handlers import ( "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -126,6 +132,55 @@ func TestRandomWeightedDraw(t *testing.T) { } } +func TestGetRandomPod(t *testing.T) { + tests := []struct { + name string + storePods []*corev1.Pod + expectNil bool + }{ + { + name: "No pods available", + storePods: []*corev1.Pod{}, + expectNil: true, + }, + { + name: "Single pod available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + }, + expectNil: false, + }, + { + name: "Multiple pods available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod3"}}, + }, + expectNil: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pmf := metrics.NewPodMetricsFactory(&metrics.FakePodMetricsClient{}, 
time.Millisecond) + ds := datastore.NewDatastore(t.Context(), pmf) + for _, pod := range test.storePods { + ds.PodUpdateOrAddIfNotExist(pod) + } + + gotPod := GetRandomPod(ds) + + if test.expectNil && gotPod != nil { + t.Errorf("expected nil pod, got: %v", gotPod) + } + if !test.expectNil && gotPod == nil { + t.Errorf("expected non-nil pod, got nil") + } + }) + } +} + func pointer(v int32) *int32 { return &v } From ec4af3b1ed48d2bb9c64cdb98208e74e8b28418a Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Thu, 24 Apr 2025 18:48:31 +0300 Subject: [PATCH 15/20] Move filter and scorer plugins registration to a separate file (#729) * Move filters and scorers registration to filter/scorer specific files * Default scheduler config contains empty list of scorers Signed-off-by: Maya Barnea * Default plugin is not a scorer any more Signed-off-by: Maya Barnea * fix scheduler test + lint comments Signed-off-by: Maya Barnea --------- Signed-off-by: Maya Barnea --- pkg/epp/scheduling/config.go | 27 ++++++++++ pkg/epp/scheduling/default_config.go | 31 +++++++++++ pkg/epp/scheduling/scheduler.go | 18 ++++--- pkg/epp/scheduling/scheduler_test.go | 81 ++++++++++++++-------------- 4 files changed, 110 insertions(+), 47 deletions(-) create mode 100644 pkg/epp/scheduling/config.go create mode 100644 pkg/epp/scheduling/default_config.go diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go new file mode 100644 index 000000000..6c0f4be7b --- /dev/null +++ b/pkg/epp/scheduling/config.go @@ -0,0 +1,27 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + +type SchedulerConfig struct { + preSchedulePlugins []plugins.PreSchedule + scorers []plugins.Scorer + filters []plugins.Filter + postSchedulePlugins []plugins.PostSchedule + picker plugins.Picker +} diff --git a/pkg/epp/scheduling/default_config.go b/pkg/epp/scheduling/default_config.go new file mode 100644 index 000000000..e42f13179 --- /dev/null +++ b/pkg/epp/scheduling/default_config.go @@ -0,0 +1,31 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scheduling + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" +) + +var defPlugin = &defaultPlugin{} + +var defaultConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + scorers: []plugins.Scorer{}, + filters: []plugins.Filter{defPlugin}, + postSchedulePlugins: []plugins.PostSchedule{}, + picker: defPlugin, +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index beac5e6b8..322f714f4 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -68,16 +68,20 @@ var ( ) func NewScheduler(datastore Datastore) *Scheduler { - defaultPlugin := &defaultPlugin{} + return NewSchedulerWithConfig(datastore, defaultConfig) +} - return &Scheduler{ +func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler { + scheduler := &Scheduler{ datastore: datastore, - preSchedulePlugins: []plugins.PreSchedule{}, - scorers: []plugins.Scorer{}, - filters: []plugins.Filter{defaultPlugin}, - postSchedulePlugins: []plugins.PostSchedule{}, - picker: defaultPlugin, + preSchedulePlugins: config.preSchedulePlugins, + scorers: config.scorers, + filters: config.filters, + postSchedulePlugins: config.postSchedulePlugins, + picker: config.picker, } + + return scheduler } type Scheduler struct { diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index cb729038e..2fb26a865 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -220,9 +220,17 @@ func TestSchedule(t *testing.T) { }, } + schedConfig := &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + scorers: []plugins.Scorer{}, + filters: []plugins.Filter{defPlugin}, + postSchedulePlugins: []plugins.PostSchedule{}, + picker: defPlugin, + } + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - scheduler := NewScheduler(&fakeDataStore{pods: test.input}) + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, schedConfig) got, err := scheduler.Schedule(context.Background(), test.req) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) @@ -257,26 +265,24 @@ func TestSchedulePlugins(t *testing.T) { } tests := []struct { - name string - preSchedulePlugins []plugins.PreSchedule - filters []plugins.Filter - scorers []plugins.Scorer - postSchedulePlugins []plugins.PostSchedule - picker plugins.Picker - input []*backendmetrics.FakePodMetrics - wantTargetPod k8stypes.NamespacedName - targetPodScore float64 + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + wantTargetPod k8stypes.NamespacedName + targetPodScore float64 // Number of expected pods to score (after filter) numPodsToScore int err bool }{ { - name: "all plugins executed successfully", - preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, - filters: []plugins.Filter{tp1, tp2}, - scorers: []plugins.Scorer{tp1, tp2}, - postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, - picker: pickerPlugin, + name: "all plugins executed successfully", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: []plugins.Scorer{tp1, tp2}, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + picker: pickerPlugin, + }, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, 
@@ -288,12 +294,14 @@ func TestSchedulePlugins(t *testing.T) { err: false, }, { - name: "filter all", - preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, - filters: []plugins.Filter{tp1, tp_filterAll}, - scorers: []plugins.Scorer{tp1, tp2}, - postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, - picker: pickerPlugin, + name: "filter all", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp_filterAll}, + scorers: []plugins.Scorer{tp1, tp2}, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + picker: pickerPlugin, + }, input: []*backendmetrics.FakePodMetrics{ {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, @@ -307,29 +315,22 @@ func TestSchedulePlugins(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Reset all plugins before each new test case. - for _, plugin := range test.preSchedulePlugins { + for _, plugin := range test.config.preSchedulePlugins { plugin.(*TestPlugin).reset() } - for _, plugin := range test.postSchedulePlugins { + for _, plugin := range test.config.postSchedulePlugins { plugin.(*TestPlugin).reset() } - for _, plugin := range test.filters { + for _, plugin := range test.config.filters { plugin.(*TestPlugin).reset() } - for _, plugin := range test.scorers { + for _, plugin := range test.config.scorers { plugin.(*TestPlugin).reset() } - test.picker.(*TestPlugin).reset() + test.config.picker.(*TestPlugin).reset() // Initialize the scheduler - scheduler := &Scheduler{ - datastore: &fakeDataStore{pods: test.input}, - preSchedulePlugins: test.preSchedulePlugins, - filters: test.filters, - scorers: test.scorers, - postSchedulePlugins: test.postSchedulePlugins, - picker: test.picker, - } + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) req := &types.LLMRequest{Model: "test-model"} got, err := scheduler.Schedule(context.Background(), req) @@ -355,35 +356,35 @@ func TestSchedulePlugins(t *testing.T) { } // Validate plugin execution counts dynamically - for _, plugin := range test.preSchedulePlugins { + for _, plugin := range test.config.preSchedulePlugins { tp, _ := plugin.(*TestPlugin) if tp.PreScheduleCallCount != 1 { t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", tp.NameRes, tp.PreScheduleCallCount) } } - for _, plugin := range test.filters { + for _, plugin := range test.config.filters { tp, _ := plugin.(*TestPlugin) if tp.FilterCallCount != 1 { t.Errorf("Plugin %s Filter() called %d times, expected 1", tp.NameRes, tp.FilterCallCount) } } - for _, plugin := range test.scorers { + for _, plugin := range test.config.scorers { tp, _ := plugin.(*TestPlugin) if tp.ScoreCallCount != test.numPodsToScore { t.Errorf("Plugin %s Score() called %d times, expected 1", tp.NameRes, tp.ScoreCallCount) } } - for _, plugin := range test.postSchedulePlugins { + for _, plugin := range test.config.postSchedulePlugins { tp, _ := plugin.(*TestPlugin) if tp.PostScheduleCallCount != 1 { t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", tp.NameRes, tp.PostScheduleCallCount) } } - tp, _ := test.picker.(*TestPlugin) + tp, _ := test.config.picker.(*TestPlugin) if tp.PickCallCount != 1 { t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.NameRes, tp.PickCallCount) } From e03802a4a602cae74a16563d838b24e7ca70b5be Mon Sep 17 00:00:00 2001 From: Kellen Swain Date: Thu, 24 Apr 2025 14:20:30 -0700 
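The SchedulerConfig introduced above turns plugin wiring into data: NewScheduler keeps the stock behaviour, while NewSchedulerWithConfig accepts an explicit plugin set, exactly as the reworked tests do. Because the config fields are unexported in this diff, such wiring has to live inside package scheduling (or go through a future exported constructor); under that assumption, and treating myFilter, myScorer and myPicker as stand-ins for concrete plugin implementations, a custom assembly could look like the sketch below.

package scheduling

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// newCustomScheduler assembles a Scheduler whose filter, scorer and picker
// chains differ from defaultConfig.
func newCustomScheduler(ds Datastore, myFilter plugins.Filter, myScorer plugins.Scorer, myPicker plugins.Picker) *Scheduler {
	cfg := &SchedulerConfig{
		preSchedulePlugins:  []plugins.PreSchedule{},
		filters:             []plugins.Filter{myFilter},
		scorers:             []plugins.Scorer{myScorer},
		postSchedulePlugins: []plugins.PostSchedule{},
		picker:              myPicker,
	}
	return NewSchedulerWithConfig(ds, cfg)
}

// scheduleOnce shows the call side; the result carries the picker's TargetPod.
func scheduleOnce(s *Scheduler, model string) error {
	result, err := s.Schedule(context.Background(), &types.LLMRequest{Model: model})
	if err != nil {
		return err
	}
	_ = result // e.g. hand result.TargetPod back to the request handler
	return nil
}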
Subject: [PATCH 16/20] Update issue templates (#738) * Update issue templates * Updates artifacts for v0.3.0-rc.1 release Signed-off-by: Kellen Swain * Updates bbr chart for v0.3.0-rc.1 release Signed-off-by: Kellen Swain * Updates artifacts for v0.3.0 release Signed-off-by: Kellen Swain * Adding blank issue template so that all issues start with label --------- Signed-off-by: Kellen Swain --- .github/ISSUE_TEMPLATE/bug_request.md | 4 +++- .github/ISSUE_TEMPLATE/feature_request.md | 3 +-- .github/ISSUE_TEMPLATE/issue_template.md | 8 ++++++++ .github/ISSUE_TEMPLATE/new-release.md | 1 + 4 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/issue_template.md diff --git a/.github/ISSUE_TEMPLATE/bug_request.md b/.github/ISSUE_TEMPLATE/bug_request.md index c2597eb32..15ed35e12 100644 --- a/.github/ISSUE_TEMPLATE/bug_request.md +++ b/.github/ISSUE_TEMPLATE/bug_request.md @@ -1,7 +1,9 @@ --- name: Bug Report about: Report a bug you encountered -labels: kind/bug +title: '' +labels: kind/bug, needs-triage +assignees: '' --- diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 53a885c7c..1eee5871b 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -2,7 +2,7 @@ name: Feature request about: Suggest an idea for this project title: '' -labels: '' +labels: needs-triage assignees: '' --- @@ -12,4 +12,3 @@ assignees: '' **What would you like to be added**: **Why is this needed**: - diff --git a/.github/ISSUE_TEMPLATE/issue_template.md b/.github/ISSUE_TEMPLATE/issue_template.md new file mode 100644 index 000000000..1a2c8c6fc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue_template.md @@ -0,0 +1,8 @@ +--- +name: Blank Issue +about: '' +title: '' +labels: needs-triage +assignees: '' + +--- \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/new-release.md b/.github/ISSUE_TEMPLATE/new-release.md index be5698441..27e837845 100644 --- a/.github/ISSUE_TEMPLATE/new-release.md +++ b/.github/ISSUE_TEMPLATE/new-release.md @@ -4,6 +4,7 @@ about: Propose a new release title: Release v0.x.0 labels: '' assignees: '' + --- - [Introduction](#introduction) From f93cbe6f0f06a09ae8c15072597e032ec84dcd32 Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 17/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 34 +++++------------------------ 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 762aec2c0..b6466e6b2 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -30,8 +30,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/client/config" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -359,10 +357,6 @@ func TestMetrics(t *testing.T) { } func TestPods(t *testing.T) { - poolSelector := map[string]string{"app": "vllm_v1"} - pool := testutil.MakeInferencePool("pool"). - Namespace("default"). 
- Selector(poolSelector).ObjRef() updatedPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: "pod1", @@ -371,10 +365,6 @@ func TestPods(t *testing.T) { NodeName: "node-1", }, } - resyncPoolSelector := map[string]string{"app": "llama3_8b"} - resyncPool := testutil.MakeInferencePool("pool"). - Namespace("default"). - Selector(resyncPoolSelector).ObjRef() tests := []struct { name string op func(ctx context.Context, ds Datastore) @@ -386,7 +376,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{}, wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(pod1, pool) + ds.PodUpdateOrAddIfNotExist(pod1) }, }, { @@ -394,7 +384,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{pod1}, wantPods: []*corev1.Pod{pod1, pod2}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(pod2, pool) + ds.PodUpdateOrAddIfNotExist(pod2) }, }, { @@ -402,7 +392,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{pod1}, wantPods: []*corev1.Pod{updatedPod}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(updatedPod, pool) + ds.PodUpdateOrAddIfNotExist(updatedPod) }, }, { @@ -416,21 +406,7 @@ func TestPods(t *testing.T) { Namespace: "default", }, } - ds.PodUpdateOrAddIfNotExist(incoming, pool) - }, - }, - { - name: "Change pool selector, resync required, should update", - existingPods: []*corev1.Pod{pod1, pod2}, - wantPods: []*corev1.Pod{pod1, pod2}, - op: func(ctx context.Context, ds Datastore) { - scheme := runtime.NewScheme() - cfg := config.GetConfigOrDie() - cli, err := client.New(cfg, client.Options{Scheme: scheme}) - if err != nil { - t.Fatalf("Unable to create ctrl runtime client") - } - ds.PodResyncAll(ctx, cli, resyncPool) + ds.PodUpdateOrAddIfNotExist(incoming) }, }, { @@ -455,7 +431,7 @@ func TestPods(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := NewDatastore(t.Context(), pmf) for _, pod := range test.existingPods { - ds.PodUpdateOrAddIfNotExist(pod, pool) + ds.PodUpdateOrAddIfNotExist(pod) } test.op(ctx, ds) From 56eb52fa5505921052e8ca34aa015b4d339d0134 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Wed, 23 Apr 2025 04:27:47 +0300 Subject: [PATCH 18/20] few updates in datastore (#713) * few updates in datastore Signed-off-by: Nir Rozenbaum * PoolSet documentation Signed-off-by: Nir Rozenbaum * error phrasing Signed-off-by: Nir Rozenbaum * removed unused pool arg from PodUpdateOrAddIfNotExist Signed-off-by: Nir Rozenbaum * linter Signed-off-by: Nir Rozenbaum --------- Signed-off-by: Nir Rozenbaum --- pkg/epp/controller/pod_reconciler.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 5f1df10d7..2156623e5 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -28,7 +28,10 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" +<<<<<<< HEAD "sigs.k8s.io/controller-runtime/pkg/predicate" +======= +>>>>>>> 7792676 (few updates in datastore (#713)) "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" From 56d20e5bffdaf64deeee3068087eaacfc5dbd931 Mon Sep 17 00:00:00 2001 From: Nir Rozenbaum Date: Wed, 23 Apr 2025 04:27:47 +0300 Subject: 
[PATCH 19/20] few updates in datastore (#713) * few updates in datastore Signed-off-by: Nir Rozenbaum * PoolSet documentation Signed-off-by: Nir Rozenbaum * error phrasing Signed-off-by: Nir Rozenbaum * removed unused pool arg from PodUpdateOrAddIfNotExist Signed-off-by: Nir Rozenbaum * linter Signed-off-by: Nir Rozenbaum --------- Signed-off-by: Nir Rozenbaum --- pkg/epp/controller/pod_reconciler.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 2156623e5..5f1df10d7 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -28,10 +28,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" -<<<<<<< HEAD "sigs.k8s.io/controller-runtime/pkg/predicate" -======= ->>>>>>> 7792676 (few updates in datastore (#713)) "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" From 124369e9d682eb75d2158b7bd66dac148f79e296 Mon Sep 17 00:00:00 2001 From: Radhika Lakhtakia Date: Fri, 18 Apr 2025 18:42:46 +0000 Subject: [PATCH 20/20] Add unit test coverage for pod APIs under datastore/pkg --- pkg/epp/datastore/datastore_test.go | 34 +++++------------------------ 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 762aec2c0..b6466e6b2 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -30,8 +30,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" clientgoscheme "k8s.io/client-go/kubernetes/scheme" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/client/config" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -359,10 +357,6 @@ func TestMetrics(t *testing.T) { } func TestPods(t *testing.T) { - poolSelector := map[string]string{"app": "vllm_v1"} - pool := testutil.MakeInferencePool("pool"). - Namespace("default"). - Selector(poolSelector).ObjRef() updatedPod := &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: "pod1", @@ -371,10 +365,6 @@ func TestPods(t *testing.T) { NodeName: "node-1", }, } - resyncPoolSelector := map[string]string{"app": "llama3_8b"} - resyncPool := testutil.MakeInferencePool("pool"). - Namespace("default"). 
- Selector(resyncPoolSelector).ObjRef() tests := []struct { name string op func(ctx context.Context, ds Datastore) @@ -386,7 +376,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{}, wantPods: []*corev1.Pod{pod1}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(pod1, pool) + ds.PodUpdateOrAddIfNotExist(pod1) }, }, { @@ -394,7 +384,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{pod1}, wantPods: []*corev1.Pod{pod1, pod2}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(pod2, pool) + ds.PodUpdateOrAddIfNotExist(pod2) }, }, { @@ -402,7 +392,7 @@ func TestPods(t *testing.T) { existingPods: []*corev1.Pod{pod1}, wantPods: []*corev1.Pod{updatedPod}, op: func(ctx context.Context, ds Datastore) { - ds.PodUpdateOrAddIfNotExist(updatedPod, pool) + ds.PodUpdateOrAddIfNotExist(updatedPod) }, }, { @@ -416,21 +406,7 @@ func TestPods(t *testing.T) { Namespace: "default", }, } - ds.PodUpdateOrAddIfNotExist(incoming, pool) - }, - }, - { - name: "Change pool selector, resync required, should update", - existingPods: []*corev1.Pod{pod1, pod2}, - wantPods: []*corev1.Pod{pod1, pod2}, - op: func(ctx context.Context, ds Datastore) { - scheme := runtime.NewScheme() - cfg := config.GetConfigOrDie() - cli, err := client.New(cfg, client.Options{Scheme: scheme}) - if err != nil { - t.Fatalf("Unable to create ctrl runtime client") - } - ds.PodResyncAll(ctx, cli, resyncPool) + ds.PodUpdateOrAddIfNotExist(incoming) }, }, { @@ -455,7 +431,7 @@ func TestPods(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) ds := NewDatastore(t.Context(), pmf) for _, pod := range test.existingPods { - ds.PodUpdateOrAddIfNotExist(pod, pool) + ds.PodUpdateOrAddIfNotExist(pod) } test.op(ctx, ds)
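Taken together, the datastore-facing changes in this series leave a simpler consumer surface: PodUpdateOrAddIfNotExist no longer takes a pool argument, and GetRandomPod explicitly returns nil when the store is empty instead of letting rand.Intn panic on a zero-length pod list. A short consumer-side sketch, assuming only the constructors and methods that appear in these diffs (the refresh interval and pod name are arbitrary):

package main

import (
	"context"
	"fmt"
	"time"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
)

func main() {
	ctx := context.Background()
	pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
	ds := datastore.NewDatastore(ctx, pmf)

	// Before any pod is tracked, GetRandomPod now reports "nothing available"
	// rather than panicking.
	if handlers.GetRandomPod(ds) == nil {
		fmt.Println("no pods available yet")
	}

	// Pods are registered with the pool-free signature introduced above.
	ds.PodUpdateOrAddIfNotExist(&corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}})

	if pod := handlers.GetRandomPod(ds); pod != nil {
		fmt.Println("picked", pod.NamespacedName)
	}
}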