From 8912c3147910fca6aded0406a5bc57b76fc5832a Mon Sep 17 00:00:00 2001 From: ahg-g Date: Sun, 16 Feb 2025 06:41:34 +0000 Subject: [PATCH 1/8] Removed the intermediate cache in provider, and consolidating all storage behind datastore. --- pkg/ext-proc/backend/datastore.go | 238 ++++++++----- pkg/ext-proc/backend/datastore_test.go | 6 +- pkg/ext-proc/backend/fake.go | 13 +- .../backend/inferencemodel_reconciler.go | 10 +- .../backend/inferencemodel_reconciler_test.go | 52 +-- .../backend/inferencepool_reconciler.go | 27 +- .../backend/inferencepool_reconciler_test.go | 156 +++++---- pkg/ext-proc/backend/pod_reconciler.go | 31 +- pkg/ext-proc/backend/pod_reconciler_test.go | 125 +++++-- pkg/ext-proc/backend/provider.go | 134 +++---- pkg/ext-proc/backend/provider_test.go | 77 ++-- pkg/ext-proc/backend/types.go | 23 +- pkg/ext-proc/backend/vllm/metrics.go | 11 +- pkg/ext-proc/handlers/request.go | 5 +- pkg/ext-proc/handlers/server.go | 24 +- pkg/ext-proc/health.go | 4 +- pkg/ext-proc/main.go | 16 +- pkg/ext-proc/scheduling/filter_test.go | 23 +- pkg/ext-proc/scheduling/scheduler.go | 27 +- pkg/ext-proc/server/runserver.go | 17 +- pkg/ext-proc/server/runserver_test.go | 2 +- pkg/ext-proc/test/benchmark/benchmark.go | 12 +- pkg/ext-proc/test/utils.go | 50 ++- pkg/ext-proc/util/testing/wrappers.go | 50 +++ test/integration/hermetic_test.go | 328 +++++++++--------- 25 files changed, 801 insertions(+), 660 deletions(-) create mode 100644 pkg/ext-proc/util/testing/wrappers.go diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index a75e7e433..05e0a9c38 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -10,136 +10,187 @@ import ( "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) -func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore { - store := &K8sDatastore{ - poolMu: sync.RWMutex{}, - InferenceModels: &sync.Map{}, - pods: &sync.Map{}, - } - for _, opt := range options { - opt(store) +// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) +type Datastore interface { + // InferencePool operations + PoolSet(pool *v1alpha1.InferencePool) + PoolGet() (*v1alpha1.InferencePool, error) + PoolHasSynced() bool + PoolLabelsMatch(podLabels map[string]string) bool + + // InferenceModel operations + ModelSet(infModel *v1alpha1.InferenceModel) + ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel) + ModelDelete(modelName string) + + // PodMetrics operations + PodAddIfNotExist(pod *corev1.Pod) bool + PodUpdateMetricsIfExist(pm *PodMetrics) + PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) + PodDelete(namespacedName types.NamespacedName) + PodFlush(ctx context.Context, ctrlClient client.Client) + PodGetAll() []*PodMetrics + PodRange(f func(key, value any) bool) + PodDeleteAll() // This is only for testing. 
+} + +func NewDatastore() Datastore { + store := &datastore{ + poolMu: sync.RWMutex{}, + models: &sync.Map{}, + pods: &sync.Map{}, } return store } -// The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) -type K8sDatastore struct { +type datastore struct { // poolMu is used to synchronize access to the inferencePool. - poolMu sync.RWMutex - inferencePool *v1alpha1.InferencePool - InferenceModels *sync.Map - pods *sync.Map -} - -type K8sDatastoreOption func(*K8sDatastore) - -// WithPods can be used in tests to override the pods. -func WithPods(pods []*PodMetrics) K8sDatastoreOption { - return func(store *K8sDatastore) { - store.pods = &sync.Map{} - for _, pod := range pods { - store.pods.Store(pod.Pod, true) - } - } + poolMu sync.RWMutex + pool *v1alpha1.InferencePool + models *sync.Map + // key: types.NamespacedName, value: *PodMetrics + pods *sync.Map } -func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) { +// /// InferencePool APIs /// +func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) { ds.poolMu.Lock() defer ds.poolMu.Unlock() - ds.inferencePool = pool + ds.pool = pool } -func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) { +func (ds *datastore) PoolGet() (*v1alpha1.InferencePool, error) { ds.poolMu.RLock() defer ds.poolMu.RUnlock() - if !ds.HasSynced() { + if !ds.PoolHasSynced() { return nil, errors.New("InferencePool is not initialized in data store") } - return ds.inferencePool, nil + return ds.pool, nil } -func (ds *K8sDatastore) GetPodIPs() []string { - var ips []string - ds.pods.Range(func(name, pod any) bool { - ips = append(ips, pod.(*corev1.Pod).Status.PodIP) - return true - }) - return ips +func (ds *datastore) PoolHasSynced() bool { + ds.poolMu.RLock() + defer ds.poolMu.RUnlock() + return ds.pool != nil +} + +func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { + poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector) + podSet := labels.Set(podLabels) + return poolSelector.Matches(podSet) } -func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) { - infModel, ok := s.InferenceModels.Load(modelName) +// /// InferenceModel APIs /// +func (ds *datastore) ModelSet(infModel *v1alpha1.InferenceModel) { + ds.models.Store(infModel.Spec.ModelName, infModel) +} + +func (ds *datastore) ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel) { + infModel, ok := ds.models.Load(modelName) if ok { returnModel = infModel.(*v1alpha1.InferenceModel) } return } -// HasSynced returns true if InferencePool is set in the data store. 
-func (ds *K8sDatastore) HasSynced() bool { - ds.poolMu.RLock() - defer ds.poolMu.RUnlock() - return ds.inferencePool != nil +func (ds *datastore) ModelDelete(modelName string) { + ds.models.Delete(modelName) } -func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string { - var weights int32 - - source := rand.NewSource(rand.Int63()) - if seed > 0 { - source = rand.NewSource(seed) - } - r := rand.New(source) - for _, model := range model.Spec.TargetModels { - weights += *model.Weight +// /// Pods/endpoints APIs /// +func (ds *datastore) PodUpdateMetricsIfExist(pm *PodMetrics) { + if val, ok := ds.pods.Load(pm.NamespacedName); ok { + existing := val.(*PodMetrics) + existing.Metrics = pm.Metrics } - logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights) - randomVal := r.Int31n(weights) - for _, model := range model.Spec.TargetModels { - if randomVal < *model.Weight { - return model.Name - } - randomVal -= *model.Weight +} + +func (ds *datastore) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) { + val, ok := ds.pods.Load(namespacedName) + if ok { + return val.(*PodMetrics), true } - return "" + return nil, false } -func IsCritical(model *v1alpha1.InferenceModel) bool { - if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical { +func (ds *datastore) PodGetAll() []*PodMetrics { + res := []*PodMetrics{} + fn := func(k, v any) bool { + res = append(res, v.(*PodMetrics)) return true } - return false + ds.pods.Range(fn) + return res } -func (ds *K8sDatastore) LabelsMatch(podLabels map[string]string) bool { - poolSelector := selectorFromInferencePoolSelector(ds.inferencePool.Spec.Selector) - podSet := labels.Set(podLabels) - return poolSelector.Matches(podSet) +func (ds *datastore) PodRange(f func(key, value any) bool) { + ds.pods.Range(f) +} + +func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { + ds.pods.Delete(namespacedName) +} + +func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool { + // new pod, add to the store for probing + pool, _ := ds.PoolGet() + new := &PodMetrics{ + NamespacedName: types.NamespacedName{ + Name: pod.Name, + Namespace: pod.Namespace, + }, + Address: pod.Status.PodIP + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)), + Metrics: Metrics{ + ActiveModels: make(map[string]int), + }, + } + if _, ok := ds.pods.Load(new.NamespacedName); !ok { + ds.pods.Store(new.NamespacedName, new) + return true + } + return false } -func (ds *K8sDatastore) flushPodsAndRefetch(ctx context.Context, ctrlClient client.Client, newServerPool *v1alpha1.InferencePool) { +func (ds *datastore) PodFlush(ctx context.Context, ctrlClient client.Client) { + // Pool must exist to invoke this function. 
+ pool, _ := ds.PoolGet() podList := &corev1.PodList{} if err := ctrlClient.List(ctx, podList, &client.ListOptions{ - LabelSelector: selectorFromInferencePoolSelector(newServerPool.Spec.Selector), - Namespace: newServerPool.Namespace, + LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector), + Namespace: pool.Namespace, }); err != nil { log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients") + return } - ds.pods.Clear() - for _, k8sPod := range podList.Items { - pod := Pod{ - Name: k8sPod.Name, - Address: k8sPod.Status.PodIP + ":" + strconv.Itoa(int(newServerPool.Spec.TargetPortNumber)), + activePods := make(map[string]bool) + for _, pod := range podList.Items { + if podIsReady(&pod) { + activePods[pod.Name] = true + ds.PodAddIfNotExist(&pod) } - ds.pods.Store(pod, true) } + + // Remove pods that don't exist or not ready any more. + deleteFn := func(k, v any) bool { + pm := v.(*PodMetrics) + if exist := activePods[pm.NamespacedName.Name]; !exist { + ds.pods.Delete(pm.NamespacedName) + } + return true + } + ds.pods.Range(deleteFn) +} + +func (ds *datastore) PodDeleteAll() { + ds.pods.Clear() } func selectorFromInferencePoolSelector(selector map[v1alpha1.LabelKey]v1alpha1.LabelValue) labels.Selector { @@ -153,3 +204,32 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha1.LabelKey]v1alpha1.LabelV } return outMap } + +func RandomWeightedDraw(logger logr.Logger, model *v1alpha1.InferenceModel, seed int64) string { + var weights int32 + + source := rand.NewSource(rand.Int63()) + if seed > 0 { + source = rand.NewSource(seed) + } + r := rand.New(source) + for _, model := range model.Spec.TargetModels { + weights += *model.Weight + } + logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights) + randomVal := r.Int31n(weights) + for _, model := range model.Spec.TargetModels { + if randomVal < *model.Weight { + return model.Name + } + randomVal -= *model.Weight + } + return "" +} + +func IsCritical(model *v1alpha1.InferenceModel) bool { + if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha1.Critical { + return true + } + return false +} diff --git a/pkg/ext-proc/backend/datastore_test.go b/pkg/ext-proc/backend/datastore_test.go index 9f74226a8..b44de0a54 100644 --- a/pkg/ext-proc/backend/datastore_test.go +++ b/pkg/ext-proc/backend/datastore_test.go @@ -32,13 +32,13 @@ func TestHasSynced(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - datastore := NewK8sDataStore() + datastore := NewDatastore() // Set the inference pool if tt.inferencePool != nil { - datastore.setInferencePool(tt.inferencePool) + datastore.PoolSet(tt.inferencePool) } // Check if the data store has been initialized - hasSynced := datastore.HasSynced() + hasSynced := datastore.PoolHasSynced() if hasSynced != tt.hasSynced { t.Errorf("IsInitialized() = %v, want %v", hasSynced, tt.hasSynced) } diff --git a/pkg/ext-proc/backend/fake.go b/pkg/ext-proc/backend/fake.go index 2c0757dbe..dfb520eff 100644 --- a/pkg/ext-proc/backend/fake.go +++ b/pkg/ext-proc/backend/fake.go @@ -3,22 +3,23 @@ package backend import ( "context" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) type FakePodMetricsClient struct { - Err map[Pod]error - Res map[Pod]*PodMetrics + Err map[types.NamespacedName]error + Res map[types.NamespacedName]*PodMetrics } -func (f 
*FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) { - if err, ok := f.Err[pod]; ok { +func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, existing *PodMetrics) (*PodMetrics, error) { + if err, ok := f.Err[existing.NamespacedName]; ok { return nil, err } - log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "pod", pod, "existing", existing, "new", f.Res[pod]) - return f.Res[pod], nil + log.FromContext(ctx).V(logutil.VERBOSE).Info("Fetching metrics for pod", "existing", existing, "new", f.Res[existing.NamespacedName]) + return f.Res[existing.NamespacedName], nil } type FakeDataStore struct { diff --git a/pkg/ext-proc/backend/inferencemodel_reconciler.go b/pkg/ext-proc/backend/inferencemodel_reconciler.go index 4959845cc..884e6b7ed 100644 --- a/pkg/ext-proc/backend/inferencemodel_reconciler.go +++ b/pkg/ext-proc/backend/inferencemodel_reconciler.go @@ -19,7 +19,7 @@ type InferenceModelReconciler struct { client.Client Scheme *runtime.Scheme Record record.EventRecorder - Datastore *K8sDatastore + Datastore Datastore PoolNamespacedName types.NamespacedName } @@ -36,14 +36,14 @@ func (c *InferenceModelReconciler) Reconcile(ctx context.Context, req ctrl.Reque if err := c.Get(ctx, req.NamespacedName, infModel); err != nil { if errors.IsNotFound(err) { loggerDefault.Info("InferenceModel not found. Removing from datastore since object must be deleted", "name", req.NamespacedName) - c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName) + c.Datastore.ModelDelete(infModel.Spec.ModelName) return ctrl.Result{}, nil } loggerDefault.Error(err, "Unable to get InferenceModel", "name", req.NamespacedName) return ctrl.Result{}, err } else if !infModel.DeletionTimestamp.IsZero() { loggerDefault.Info("InferenceModel is marked for deletion. Removing from datastore", "name", req.NamespacedName) - c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName) + c.Datastore.ModelDelete(infModel.Spec.ModelName) return ctrl.Result{}, nil } @@ -57,12 +57,12 @@ func (c *InferenceModelReconciler) updateDatastore(logger logr.Logger, infModel if infModel.Spec.PoolRef.Name == c.PoolNamespacedName.Name { loggerDefault.Info("Updating datastore", "poolRef", infModel.Spec.PoolRef, "serverPoolName", c.PoolNamespacedName) loggerDefault.Info("Adding/Updating InferenceModel", "modelName", infModel.Spec.ModelName) - c.Datastore.InferenceModels.Store(infModel.Spec.ModelName, infModel) + c.Datastore.ModelSet(infModel) return } loggerDefault.Info("Removing/Not adding InferenceModel", "modelName", infModel.Spec.ModelName) // If we get here. The model is not relevant to this pool, remove. 
- c.Datastore.InferenceModels.Delete(infModel.Spec.ModelName) + c.Datastore.ModelDelete(infModel.Spec.ModelName) } func (c *InferenceModelReconciler) SetupWithManager(mgr ctrl.Manager) error { diff --git a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go index 4e1958181..67872636e 100644 --- a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go +++ b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go @@ -51,14 +51,14 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { tests := []struct { name string - datastore *K8sDatastore + datastore *datastore incomingService *v1alpha1.InferenceModel wantInferenceModels *sync.Map }{ { name: "No Services registered; valid, new service incoming.", - datastore: &K8sDatastore{ - inferencePool: &v1alpha1.InferencePool{ + datastore: &datastore{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, }, @@ -67,15 +67,15 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { ResourceVersion: "Old and boring", }, }, - InferenceModels: &sync.Map{}, + models: &sync.Map{}, }, incomingService: infModel1, wantInferenceModels: populateServiceMap(infModel1), }, { name: "Removing existing service.", - datastore: &K8sDatastore{ - inferencePool: &v1alpha1.InferencePool{ + datastore: &datastore{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, }, @@ -84,15 +84,15 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { ResourceVersion: "Old and boring", }, }, - InferenceModels: populateServiceMap(infModel1), + models: populateServiceMap(infModel1), }, incomingService: infModel1Modified, wantInferenceModels: populateServiceMap(), }, { name: "Unrelated service, do nothing.", - datastore: &K8sDatastore{ - inferencePool: &v1alpha1.InferencePool{ + datastore: &datastore{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, }, @@ -101,7 +101,7 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { ResourceVersion: "Old and boring", }, }, - InferenceModels: populateServiceMap(infModel1), + models: populateServiceMap(infModel1), }, incomingService: &v1alpha1.InferenceModel{ Spec: v1alpha1.InferenceModelSpec{ @@ -116,8 +116,8 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { }, { name: "Add to existing", - datastore: &K8sDatastore{ - inferencePool: &v1alpha1.InferencePool{ + datastore: &datastore{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, }, @@ -126,7 +126,7 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { ResourceVersion: "Old and boring", }, }, - InferenceModels: populateServiceMap(infModel1), + models: populateServiceMap(infModel1), }, incomingService: infModel2, wantInferenceModels: populateServiceMap(infModel1, infModel2), @@ -136,11 +136,11 @@ func TestUpdateDatastore_InferenceModelReconciler(t *testing.T) { t.Run(test.name, func(t *testing.T) { reconciler := &InferenceModelReconciler{ Datastore: test.datastore, - PoolNamespacedName: types.NamespacedName{Name: test.datastore.inferencePool.Name}, + PoolNamespacedName: types.NamespacedName{Name: test.datastore.pool.Name}, } reconciler.updateDatastore(logger, test.incomingService) - if ok := 
mapsEqual(reconciler.Datastore.InferenceModels, test.wantInferenceModels); !ok { + if ok := mapsEqual(test.datastore.models, test.wantInferenceModels); !ok { t.Error("Maps are not equal") } }) @@ -156,9 +156,9 @@ func TestReconcile_ResourceNotFound(t *testing.T) { fakeClient := fake.NewClientBuilder().WithScheme(scheme).Build() // Create a minimal datastore. - datastore := &K8sDatastore{ - InferenceModels: &sync.Map{}, - inferencePool: &v1alpha1.InferencePool{ + datastore := &datastore{ + models: &sync.Map{}, + pool: &v1alpha1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, }, } @@ -211,9 +211,9 @@ func TestReconcile_ModelMarkedForDeletion(t *testing.T) { fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(existingModel).Build() // Create a minimal datastore. - datastore := &K8sDatastore{ - InferenceModels: &sync.Map{}, - inferencePool: &v1alpha1.InferencePool{ + datastore := &datastore{ + models: &sync.Map{}, + pool: &v1alpha1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, }, } @@ -242,7 +242,7 @@ func TestReconcile_ModelMarkedForDeletion(t *testing.T) { } // Verify that the datastore was not updated. - if _, ok := datastore.InferenceModels.Load(existingModel.Spec.ModelName); ok { + if infModel := datastore.ModelGet(existingModel.Spec.ModelName); infModel != nil { t.Errorf("expected datastore to not contain model %q", existingModel.Spec.ModelName) } } @@ -268,9 +268,9 @@ func TestReconcile_ResourceExists(t *testing.T) { fakeClient := fake.NewClientBuilder().WithScheme(scheme).WithObjects(existingModel).Build() // Create a minimal datastore. - datastore := &K8sDatastore{ - InferenceModels: &sync.Map{}, - inferencePool: &v1alpha1.InferencePool{ + datastore := &datastore{ + models: &sync.Map{}, + pool: &v1alpha1.InferencePool{ ObjectMeta: metav1.ObjectMeta{Name: "test-pool"}, }, } @@ -299,7 +299,7 @@ func TestReconcile_ResourceExists(t *testing.T) { } // Verify that the datastore was updated. 
- if _, ok := datastore.InferenceModels.Load(existingModel.Spec.ModelName); !ok { + if infModel := datastore.ModelGet(existingModel.Spec.ModelName); infModel == nil { t.Errorf("expected datastore to contain model %q", existingModel.Spec.ModelName) } } diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index e44a278ae..36a7a60c2 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -4,7 +4,6 @@ import ( "context" "reflect" - "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -24,7 +23,7 @@ type InferencePoolReconciler struct { Scheme *runtime.Scheme Record record.EventRecorder PoolNamespacedName types.NamespacedName - Datastore *K8sDatastore + Datastore Datastore } func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { @@ -40,23 +39,23 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques if err := c.Get(ctx, req.NamespacedName, serverPool); err != nil { loggerDefault.Error(err, "Unable to get InferencePool", "name", req.NamespacedName) return ctrl.Result{}, err - } - if c.Datastore.inferencePool == nil || !reflect.DeepEqual(serverPool.Spec.Selector, c.Datastore.inferencePool.Spec.Selector) { - c.updateDatastore(logger, serverPool) - c.Datastore.flushPodsAndRefetch(ctx, c.Client, serverPool) - } else { - c.updateDatastore(logger, serverPool) + + // TODO: Handle InferencePool deletions. Need to flush the datastore. + // TODO: Handle port updates, podMetrics should not be storing that as part of the address. } + c.updateDatastore(ctx, serverPool) + return ctrl.Result{}, nil } -func (c *InferencePoolReconciler) updateDatastore(logger logr.Logger, serverPool *v1alpha1.InferencePool) { - pool, _ := c.Datastore.getInferencePool() - if pool == nil || - serverPool.ObjectMeta.ResourceVersion != pool.ObjectMeta.ResourceVersion { - logger.V(logutil.DEFAULT).Info("Updating inference pool", "target", klog.KMetadata(&serverPool.ObjectMeta)) - c.Datastore.setInferencePool(serverPool) +func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *v1alpha1.InferencePool) { + logger := log.FromContext(ctx) + oldPool, _ := c.Datastore.PoolGet() + c.Datastore.PoolSet(newPool) + if oldPool == nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { + logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "target", klog.KMetadata(&newPool.ObjectMeta)) + c.Datastore.PodFlush(ctx, c.Client) } } diff --git a/pkg/ext-proc/backend/inferencepool_reconciler_test.go b/pkg/ext-proc/backend/inferencepool_reconciler_test.go index 1da7d61b0..c1c700a0c 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler_test.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler_test.go @@ -1,88 +1,118 @@ package backend import ( - "reflect" + "context" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" - logutil 
"sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" + utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" ) var ( - pool1 = &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, - }, + selector_v1 = map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm_v1"} + selector_v2 = map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm_v2"} + pool1 = &v1alpha1.InferencePool{ ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool", - ResourceVersion: "50", + Name: "pool1", + Namespace: "pool1-ns", }, + Spec: v1alpha1.InferencePoolSpec{Selector: selector_v1}, } - // Different name, same RV doesn't really make sense, but helps with testing the - // updateStore impl which relies on the equality of RVs alone. - modPool1SameRV = &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, - }, + pool2 = &v1alpha1.InferencePool{ ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool-mod", - ResourceVersion: "50", + Name: "pool2", + Namespace: "pool2-ns", }, } - modPool1DiffRV = &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{"app": "vllm"}, - }, - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pool-mod", - ResourceVersion: "51", - }, + pods = []corev1.Pod{ + // Two ready pods matching pool1 + utiltesting.MakePod("pod1", "pool1-ns").Labels(stripLabelKeyAliasFromLabelMap(selector_v1)).ReadyCondition().Obj(), + utiltesting.MakePod("pod2", "pool1-ns").Labels(stripLabelKeyAliasFromLabelMap(selector_v1)).ReadyCondition().Obj(), + // A not ready pod matching pool1 + utiltesting.MakePod("pod3", "pool1-ns").Labels(stripLabelKeyAliasFromLabelMap(selector_v1)).Obj(), + // A pod not matching pool1 namespace + utiltesting.MakePod("pod4", "pool2-ns").Labels(stripLabelKeyAliasFromLabelMap(selector_v1)).ReadyCondition().Obj(), + // A ready pod matching pool1 with a new selector + utiltesting.MakePod("pod5", "pool1-ns").Labels(stripLabelKeyAliasFromLabelMap(selector_v2)).ReadyCondition().Obj(), } ) -func TestUpdateDatastore_InferencePoolReconciler(t *testing.T) { - logger := logutil.NewTestLogger() +func TestReconcile_InferencePoolReconciler(t *testing.T) { + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha1.AddToScheme(scheme) - tests := []struct { - name string - datastore *K8sDatastore - incomingPool *v1alpha1.InferencePool - wantPool *v1alpha1.InferencePool - }{ - { - name: "InferencePool not set, should set InferencePool", - datastore: &K8sDatastore{}, - incomingPool: pool1.DeepCopy(), - wantPool: pool1, - }, - { - name: "InferencePool set, matching RVs, do nothing", - datastore: &K8sDatastore{ - inferencePool: pool1.DeepCopy(), - }, - incomingPool: modPool1SameRV.DeepCopy(), - wantPool: pool1, - }, - { - name: "InferencePool set, differing RVs, re-set InferencePool", - datastore: &K8sDatastore{ - inferencePool: pool1.DeepCopy(), - }, - incomingPool: modPool1DiffRV.DeepCopy(), - wantPool: modPool1DiffRV, - }, + // Create a fake client with the pool and the pods. + initialObjects := []client.Object{pool1, pool2} + for i := range pods { + initialObjects = append(initialObjects, &pods[i]) } + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initialObjects...). + Build() + + // Create a request for the existing resource. 
+ namespacedName := types.NamespacedName{Name: pool1.Name, Namespace: pool1.Namespace} + req := ctrl.Request{NamespacedName: namespacedName} + ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - inferencePoolReconciler := &InferencePoolReconciler{Datastore: test.datastore} - inferencePoolReconciler.updateDatastore(logger, test.incomingPool) + datastore := NewDatastore() + inferencePoolReconciler := &InferencePoolReconciler{PoolNamespacedName: namespacedName, Client: fakeClient, Datastore: datastore} + + // Step 1: Inception, only ready pods matching pool1 are added to the store. + if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + if diff := diffPool(datastore, pool1, []string{"pod1", "pod2"}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } - gotPool := inferencePoolReconciler.Datastore.inferencePool - if !reflect.DeepEqual(gotPool, test.wantPool) { - t.Errorf("Unexpected InferencePool: want %#v, got: %#v", test.wantPool, gotPool) - } - }) + // Step 2: A reconcile on pool2 should not change anything. + if _, err := inferencePoolReconciler.Reconcile(ctx, ctrl.Request{NamespacedName: types.NamespacedName{Name: pool2.Name, Namespace: pool2.Namespace}}); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + if diff := diffPool(datastore, pool1, []string{"pod1", "pod2"}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } + + // Step 3: update the pool selector to include more pods + newPool1 := &v1alpha1.InferencePool{} + if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { + t.Errorf("Unexpected pool get error: %v", err) + } + newPool1.Spec.Selector = selector_v2 + if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { + t.Errorf("Unexpected pool update error: %v", err) + } + + if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } +} + +func diffPool(datastore Datastore, wantPool *v1alpha1.InferencePool, wantPods []string) string { + gotPool, _ := datastore.PoolGet() + if diff := cmp.Diff(wantPool, gotPool); diff != "" { + return diff + } + gotPods := []string{} + for _, pm := range datastore.PodGetAll() { + gotPods = append(gotPods, pm.NamespacedName.Name) } + return cmp.Diff(wantPods, gotPods, cmpopts.SortSlices(func(a, b string) bool { return a < b })) } diff --git a/pkg/ext-proc/backend/pod_reconciler.go b/pkg/ext-proc/backend/pod_reconciler.go index b914ea8d2..9bfe3dc89 100644 --- a/pkg/ext-proc/backend/pod_reconciler.go +++ b/pkg/ext-proc/backend/pod_reconciler.go @@ -2,29 +2,29 @@ package backend import ( "context" - "strconv" + "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) type PodReconciler struct { client.Client - Datastore *K8sDatastore + Datastore Datastore Scheme *runtime.Scheme Record 
record.EventRecorder } func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - inferencePool, err := c.Datastore.getInferencePool() + inferencePool, err := c.Datastore.PoolGet() if err != nil { logger.V(logutil.TRACE).Info("Skipping reconciling Pod because the InferencePool is not available yet", "error", err) // When the inferencePool is initialized it lists the appropriate pods and populates the datastore, so no need to requeue. @@ -38,15 +38,14 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R pod := &corev1.Pod{} if err := c.Get(ctx, req.NamespacedName, pod); err != nil { if apierrors.IsNotFound(err) { - c.Datastore.pods.Delete(pod) + c.Datastore.PodDelete(req.NamespacedName) return ctrl.Result{}, nil } logger.V(logutil.DEFAULT).Error(err, "Unable to get pod", "name", req.NamespacedName) return ctrl.Result{}, err } - c.updateDatastore(pod, inferencePool) - + c.updateDatastore(logger, pod) return ctrl.Result{}, nil } @@ -56,15 +55,17 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(c) } -func (c *PodReconciler) updateDatastore(k8sPod *corev1.Pod, inferencePool *v1alpha1.InferencePool) { - pod := Pod{ - Name: k8sPod.Name, - Address: k8sPod.Status.PodIP + ":" + strconv.Itoa(int(inferencePool.Spec.TargetPortNumber)), - } - if !k8sPod.DeletionTimestamp.IsZero() || !c.Datastore.LabelsMatch(k8sPod.ObjectMeta.Labels) || !podIsReady(k8sPod) { - c.Datastore.pods.Delete(pod) +func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { + namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podIsReady(pod) { + logger.V(logutil.DEFAULT).Info("Pod removed or not added", "name", namespacedName) + c.Datastore.PodDelete(namespacedName) } else { - c.Datastore.pods.Store(pod, true) + if c.Datastore.PodAddIfNotExist(pod) { + logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) + } else { + logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) + } } } diff --git a/pkg/ext-proc/backend/pod_reconciler_test.go b/pkg/ext-proc/backend/pod_reconciler_test.go index 42d6d8e42..c2522fbba 100644 --- a/pkg/ext-proc/backend/pod_reconciler_test.go +++ b/pkg/ext-proc/backend/pod_reconciler_test.go @@ -1,33 +1,42 @@ package backend import ( + "context" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" ) var ( - basePod1 = Pod{Name: "pod1", Address: ":8000"} - basePod2 = Pod{Name: "pod2", Address: ":8000"} - basePod3 = Pod{Name: "pod3", Address: ":8000"} + basePod1 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: ":8000"} + basePod2 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod2"}, Address: ":8000"} + basePod3 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod3"}, Address: ":8000"} ) func TestUpdateDatastore_PodReconciler(t *testing.T) { + now := metav1.Now() tests := []struct { name string - datastore *K8sDatastore + datastore 
Datastore incomingPod *corev1.Pod - wantPods []string + wantPods []types.NamespacedName + req *ctrl.Request }{ { name: "Add new pod", - datastore: &K8sDatastore{ + datastore: &datastore{ pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ @@ -52,13 +61,62 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []string{basePod1.Name, basePod2.Name, basePod3.Name}, + wantPods: []types.NamespacedName{basePod1.NamespacedName, basePod2.NamespacedName, basePod3.NamespacedName}, + }, + { + name: "Delete pod with DeletionTimestamp", + datastore: &datastore{ + pods: populateMap(basePod1, basePod2), + pool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ + "some-key": "some-val", + }, + }, + }, + }, + incomingPod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Labels: map[string]string{ + "some-key": "some-val", + }, + DeletionTimestamp: &now, + Finalizers: []string{"finalizer"}, + }, + Status: corev1.PodStatus{ + Conditions: []corev1.PodCondition{ + { + Type: corev1.PodReady, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + wantPods: []types.NamespacedName{basePod2.NamespacedName}, + }, + { + name: "Delete notfound pod", + datastore: &datastore{ + pods: populateMap(basePod1, basePod2), + pool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ + "some-key": "some-val", + }, + }, + }, + }, + req: &ctrl.Request{NamespacedName: types.NamespacedName{Name: "pod1"}}, + wantPods: []types.NamespacedName{basePod2.NamespacedName}, }, { name: "New pod, not ready, valid selector", - datastore: &K8sDatastore{ + datastore: &datastore{ pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ @@ -83,13 +141,13 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []string{basePod1.Name, basePod2.Name}, + wantPods: []types.NamespacedName{basePod1.NamespacedName, basePod2.NamespacedName}, }, { name: "Remove pod that does not match selector", - datastore: &K8sDatastore{ + datastore: &datastore{ pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ @@ -114,13 +172,13 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []string{basePod2.Name}, + wantPods: []types.NamespacedName{basePod2.NamespacedName}, }, { name: "Remove pod that is not ready", - datastore: &K8sDatastore{ + datastore: &datastore{ pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ + pool: &v1alpha1.InferencePool{ Spec: v1alpha1.InferencePoolSpec{ TargetPortNumber: int32(8000), Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ @@ -145,22 +203,41 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []string{basePod2.Name}, + wantPods: []types.NamespacedName{basePod2.NamespacedName}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - podReconciler := 
&PodReconciler{Datastore: test.datastore} - podReconciler.updateDatastore(test.incomingPod, test.datastore.inferencePool) - var gotPods []string - test.datastore.pods.Range(func(k, v any) bool { - pod := k.(Pod) + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + initialObjects := []client.Object{} + if test.incomingPod != nil { + initialObjects = append(initialObjects, test.incomingPod) + } + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(initialObjects...). + Build() + + podReconciler := &PodReconciler{Client: fakeClient, Datastore: test.datastore} + namespacedName := types.NamespacedName{Name: test.incomingPod.Name, Namespace: test.incomingPod.Namespace} + if test.req == nil { + test.req = &ctrl.Request{NamespacedName: namespacedName} + } + if _, err := podReconciler.Reconcile(context.Background(), *test.req); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + + var gotPods []types.NamespacedName + test.datastore.PodRange(func(k, v any) bool { + pod := v.(*PodMetrics) if v != nil { - gotPods = append(gotPods, pod.Name) + gotPods = append(gotPods, pod.NamespacedName) } return true }) - if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b string) bool { return a < b })) { + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b types.NamespacedName) bool { return a.String() < b.String() })) { t.Errorf("got (%v) != want (%v);", gotPods, test.wantPods) } }) diff --git a/pkg/ext-proc/backend/provider.go b/pkg/ext-proc/backend/provider.go index ce7389864..7e55947da 100644 --- a/pkg/ext-proc/backend/provider.go +++ b/pkg/ext-proc/backend/provider.go @@ -8,6 +8,7 @@ import ( "github.com/go-logr/logr" "go.uber.org/multierr" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) @@ -16,72 +17,38 @@ const ( fetchMetricsTimeout = 5 * time.Second ) -func NewProvider(pmc PodMetricsClient, datastore *K8sDatastore) *Provider { +func NewProvider(pmc PodMetricsClient, datastore Datastore) *Provider { p := &Provider{ - podMetrics: sync.Map{}, - pmc: pmc, - datastore: datastore, + pmc: pmc, + datastore: datastore, } return p } // Provider provides backend pods and information such as metrics. 
type Provider struct { - // key: Pod, value: *PodMetrics - podMetrics sync.Map - pmc PodMetricsClient - datastore *K8sDatastore + pmc PodMetricsClient + datastore Datastore } type PodMetricsClient interface { - FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) + FetchMetrics(ctx context.Context, existing *PodMetrics) (*PodMetrics, error) } -func (p *Provider) AllPodMetrics() []*PodMetrics { - res := []*PodMetrics{} - fn := func(k, v any) bool { - res = append(res, v.(*PodMetrics)) - return true - } - p.podMetrics.Range(fn) - return res -} - -func (p *Provider) UpdatePodMetrics(pod Pod, pm *PodMetrics) { - p.podMetrics.Store(pod, pm) -} - -func (p *Provider) GetPodMetrics(pod Pod) (*PodMetrics, bool) { - val, ok := p.podMetrics.Load(pod) - if ok { - return val.(*PodMetrics), true - } - return nil, false -} - -func (p *Provider) Init(logger logr.Logger, refreshPodsInterval, refreshMetricsInterval, refreshPrometheusMetricsInterval time.Duration) error { - p.refreshPodsOnce() - - if err := p.refreshMetricsOnce(logger); err != nil { - logger.Error(err, "Failed to init metrics") - } - - logger.Info("Initialized pods and metrics", "metrics", p.AllPodMetrics()) - - // periodically refresh pods - go func() { - for { - time.Sleep(refreshPodsInterval) - p.refreshPodsOnce() - } - }() - +func (p *Provider) Init(ctx context.Context, refreshMetricsInterval, refreshPrometheusMetricsInterval time.Duration) error { // periodically refresh metrics + logger := log.FromContext(ctx) go func() { for { - time.Sleep(refreshMetricsInterval) - if err := p.refreshMetricsOnce(logger); err != nil { - logger.V(logutil.DEFAULT).Error(err, "Failed to refresh metrics") + select { + case <-ctx.Done(): + logger.V(logutil.DEFAULT).Info("Shutting down metrics prober") + return + default: + time.Sleep(refreshMetricsInterval) + if err := p.refreshMetricsOnce(logger); err != nil { + logger.V(logutil.DEFAULT).Error(err, "Failed to refresh metrics") + } } } }() @@ -89,8 +56,14 @@ func (p *Provider) Init(logger logr.Logger, refreshPodsInterval, refreshMetricsI // Periodically flush prometheus metrics for inference pool go func() { for { - time.Sleep(refreshPrometheusMetricsInterval) - p.flushPrometheusMetricsOnce(logger) + select { + case <-ctx.Done(): + logger.V(logutil.DEFAULT).Info("Shutting down prometheus metrics thread") + return + default: + time.Sleep(refreshPrometheusMetricsInterval) + p.flushPrometheusMetricsOnce(logger) + } } }() @@ -98,8 +71,14 @@ func (p *Provider) Init(logger logr.Logger, refreshPodsInterval, refreshMetricsI if logger := logger.V(logutil.DEBUG); logger.Enabled() { go func() { for { - time.Sleep(5 * time.Second) - logger.Info("Current Pods and metrics gathered", "metrics", p.AllPodMetrics()) + select { + case <-ctx.Done(): + logger.V(logutil.DEFAULT).Info("Shutting down metrics logger thread") + return + default: + time.Sleep(5 * time.Second) + logger.Info("Current Pods and metrics gathered", "metrics", p.datastore.PodGetAll()) + } } }() } @@ -107,36 +86,6 @@ func (p *Provider) Init(logger logr.Logger, refreshPodsInterval, refreshMetricsI return nil } -// refreshPodsOnce lists pods and updates keys in the podMetrics map. -// Note this function doesn't update the PodMetrics value, it's done separately. -func (p *Provider) refreshPodsOnce() { - // merge new pods with cached ones. 
- // add new pod to the map - addNewPods := func(k, v any) bool { - pod := k.(Pod) - if _, ok := p.podMetrics.Load(pod); !ok { - new := &PodMetrics{ - Pod: pod, - Metrics: Metrics{ - ActiveModels: make(map[string]int), - }, - } - p.podMetrics.Store(pod, new) - } - return true - } - // remove pods that don't exist any more. - mergeFn := func(k, v any) bool { - pod := k.(Pod) - if _, ok := p.datastore.pods.Load(pod); !ok { - p.podMetrics.Delete(pod) - } - return true - } - p.podMetrics.Range(mergeFn) - p.datastore.pods.Range(addNewPods) -} - func (p *Provider) refreshMetricsOnce(logger logr.Logger) error { loggerTrace := logger.V(logutil.TRACE) ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout) @@ -151,22 +100,21 @@ func (p *Provider) refreshMetricsOnce(logger logr.Logger) error { errCh := make(chan error) processOnePod := func(key, value any) bool { loggerTrace.Info("Pod and metric being processed", "pod", key, "metric", value) - pod := key.(Pod) existing := value.(*PodMetrics) wg.Add(1) go func() { defer wg.Done() - updated, err := p.pmc.FetchMetrics(ctx, pod, existing) + updated, err := p.pmc.FetchMetrics(ctx, existing) if err != nil { - errCh <- fmt.Errorf("failed to parse metrics from %s: %v", pod, err) + errCh <- fmt.Errorf("failed to parse metrics from %s: %v", existing.NamespacedName, err) return } - p.UpdatePodMetrics(pod, updated) - loggerTrace.Info("Updated metrics for pod", "pod", pod, "metrics", updated.Metrics) + p.datastore.PodUpdateMetricsIfExist(updated) + loggerTrace.Info("Updated metrics for pod", "pod", updated.NamespacedName, "metrics", updated.Metrics) }() return true } - p.podMetrics.Range(processOnePod) + p.datastore.PodRange(processOnePod) // Wait for metric collection for all pods to complete and close the error channel in a // goroutine so this is unblocking, allowing the code to proceed to the error collection code @@ -188,7 +136,7 @@ func (p *Provider) refreshMetricsOnce(logger logr.Logger) error { func (p *Provider) flushPrometheusMetricsOnce(logger logr.Logger) { logger.V(logutil.DEBUG).Info("Flushing Prometheus Metrics") - pool, _ := p.datastore.getInferencePool() + pool, _ := p.datastore.PoolGet() if pool == nil { // No inference pool or not initialize. 
return @@ -197,7 +145,7 @@ func (p *Provider) flushPrometheusMetricsOnce(logger logr.Logger) { var kvCacheTotal float64 var queueTotal int - podMetrics := p.AllPodMetrics() + podMetrics := p.datastore.PodGetAll() if len(podMetrics) == 0 { return } diff --git a/pkg/ext-proc/backend/provider_test.go b/pkg/ext-proc/backend/provider_test.go index 955750463..d7d047802 100644 --- a/pkg/ext-proc/backend/provider_test.go +++ b/pkg/ext-proc/backend/provider_test.go @@ -1,19 +1,20 @@ package backend import ( - "errors" "sync" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "k8s.io/apimachinery/pkg/types" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) var ( pod1 = &PodMetrics{ - Pod: Pod{Name: "pod1"}, + NamespacedName: types.NamespacedName{ + Name: "pod1", + }, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -25,7 +26,9 @@ var ( }, } pod2 = &PodMetrics{ - Pod: Pod{Name: "pod2"}, + NamespacedName: types.NamespacedName{ + Name: "pod2", + }, Metrics: Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.2, @@ -44,48 +47,36 @@ func TestProvider(t *testing.T) { tests := []struct { name string pmc PodMetricsClient - datastore *K8sDatastore - initErr bool + datastore Datastore want []*PodMetrics }{ - { - name: "Init success", - datastore: &K8sDatastore{ - pods: populateMap(pod1.Pod, pod2.Pod), - }, - pmc: &FakePodMetricsClient{ - Res: map[Pod]*PodMetrics{ - pod1.Pod: pod1, - pod2.Pod: pod2, - }, - }, - want: []*PodMetrics{pod1, pod2}, - }, { name: "Fetch metrics error", pmc: &FakePodMetricsClient{ - Err: map[Pod]error{ - pod2.Pod: errors.New("injected error"), - }, - Res: map[Pod]*PodMetrics{ - pod1.Pod: pod1, + // Err: map[string]error{ + // pod2.Name: errors.New("injected error"), + // }, + Res: map[types.NamespacedName]*PodMetrics{ + pod1.NamespacedName: pod1, + pod2.NamespacedName: pod2, }, }, - datastore: &K8sDatastore{ - pods: populateMap(pod1.Pod, pod2.Pod), + datastore: &datastore{ + pods: populateMap(pod1, pod2), }, want: []*PodMetrics{ pod1, - // Failed to fetch pod2 metrics so it remains the default values. - { - Pod: Pod{Name: "pod2"}, - Metrics: Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0, - MaxActiveModels: 0, - ActiveModels: map[string]int{}, - }, - }, + pod2, + // // Failed to fetch pod2 metrics so it remains the default values. 
+ // { + // Name: "pod2", + // Metrics: Metrics{ + // WaitingQueueSize: 0, + // KVCacheUsagePercent: 0, + // MaxActiveModels: 0, + // ActiveModels: map[string]int{}, + // }, + // }, }, }, } @@ -93,11 +84,11 @@ func TestProvider(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { p := NewProvider(test.pmc, test.datastore) - err := p.Init(logger, time.Millisecond, time.Millisecond, time.Millisecond) - if test.initErr != (err != nil) { - t.Fatalf("Unexpected error, got: %v, want: %v", err, test.initErr) - } - metrics := p.AllPodMetrics() + // if err := p.refreshMetricsOnce(logger); err != nil { + // t.Fatalf("Unexpected error: %v", err) + // } + _ = p.refreshMetricsOnce(logger) + metrics := test.datastore.PodGetAll() lessFunc := func(a, b *PodMetrics) bool { return a.String() < b.String() } @@ -108,10 +99,10 @@ func TestProvider(t *testing.T) { } } -func populateMap(pods ...Pod) *sync.Map { +func populateMap(pods ...*PodMetrics) *sync.Map { newMap := &sync.Map{} for _, pod := range pods { - newMap.Store(pod, true) + newMap.Store(pod.NamespacedName, pod) } return newMap } diff --git a/pkg/ext-proc/backend/types.go b/pkg/ext-proc/backend/types.go index 7e399fedc..053c80d28 100644 --- a/pkg/ext-proc/backend/types.go +++ b/pkg/ext-proc/backend/types.go @@ -1,18 +1,11 @@ // Package backend is a library to interact with backend model servers such as probing metrics. package backend -import "fmt" +import ( + "fmt" -type PodSet map[Pod]bool - -type Pod struct { - Name string - Address string -} - -func (p Pod) String() string { - return p.Name + ":" + p.Address -} + "k8s.io/apimachinery/pkg/types" +) type Metrics struct { // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. @@ -26,12 +19,13 @@ type Metrics struct { } type PodMetrics struct { - Pod + NamespacedName types.NamespacedName + Address string Metrics } func (pm *PodMetrics) String() string { - return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics) + return fmt.Sprintf("Pod: %+v; Address: %+v; Metrics: %+v", pm.NamespacedName, pm.Address, pm.Metrics) } func (pm *PodMetrics) Clone() *PodMetrics { @@ -40,7 +34,8 @@ func (pm *PodMetrics) Clone() *PodMetrics { cm[k] = v } clone := &PodMetrics{ - Pod: pm.Pod, + NamespacedName: pm.NamespacedName, + Address: pm.Address, Metrics: Metrics{ ActiveModels: cm, RunningQueueSize: pm.RunningQueueSize, diff --git a/pkg/ext-proc/backend/vllm/metrics.go b/pkg/ext-proc/backend/vllm/metrics.go index 4558a6642..3737425dd 100644 --- a/pkg/ext-proc/backend/vllm/metrics.go +++ b/pkg/ext-proc/backend/vllm/metrics.go @@ -38,7 +38,6 @@ type PodMetricsClientImpl struct{} // FetchMetrics fetches metrics from a given pod. func (p *PodMetricsClientImpl) FetchMetrics( ctx context.Context, - pod backend.Pod, existing *backend.PodMetrics, ) (*backend.PodMetrics, error) { logger := log.FromContext(ctx) @@ -46,7 +45,7 @@ func (p *PodMetricsClientImpl) FetchMetrics( // Currently the metrics endpoint is hard-coded, which works with vLLM. // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16): Consume this from InferencePool config. 
- url := fmt.Sprintf("http://%s/metrics", pod.Address) + url := fmt.Sprintf("http://%s/metrics", existing.Address) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { loggerDefault.Error(err, "Failed create HTTP request", "method", http.MethodGet, "url", url) @@ -54,16 +53,16 @@ func (p *PodMetricsClientImpl) FetchMetrics( } resp, err := http.DefaultClient.Do(req) if err != nil { - loggerDefault.Error(err, "Failed to fetch metrics", "pod", pod) - return nil, fmt.Errorf("failed to fetch metrics from %s: %w", pod, err) + loggerDefault.Error(err, "Failed to fetch metrics", "pod", existing.NamespacedName) + return nil, fmt.Errorf("failed to fetch metrics from %s: %w", existing.NamespacedName, err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { - loggerDefault.Error(nil, "Unexpected status code returned", "pod", pod, "statusCode", resp.StatusCode) - return nil, fmt.Errorf("unexpected status code from %s: %v", pod, resp.StatusCode) + loggerDefault.Error(nil, "Unexpected status code returned", "pod", existing.NamespacedName, "statusCode", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code from %s: %v", existing.NamespacedName, resp.StatusCode) } parser := expfmt.TextParser{} diff --git a/pkg/ext-proc/handlers/request.go b/pkg/ext-proc/handlers/request.go index 8ce2956f8..def75c2fc 100644 --- a/pkg/ext-proc/handlers/request.go +++ b/pkg/ext-proc/handlers/request.go @@ -48,7 +48,7 @@ func (s *Server) HandleRequestBody( // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. // This might be a security risk in the future where adapters not registered in the InferenceModel // are able to be requested by using their distinct name. - modelObj := s.datastore.FetchModelData(model) + modelObj := s.datastore.ModelGet(model) if modelObj == nil { return nil, fmt.Errorf("error finding a model object in InferenceModel for input %v", model) } @@ -88,7 +88,8 @@ func (s *Server) HandleRequestBody( reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel reqCtx.RequestSize = len(v.RequestBody.Body) - reqCtx.TargetPod = targetPod + reqCtx.TargetPod = targetPod.NamespacedName.String() + reqCtx.TargetPodAddress = targetPod.Address // Insert target endpoint to instruct Envoy to route requests to the specified target pod. 
headers := []*configPb.HeaderValueOption{ diff --git a/pkg/ext-proc/handlers/server.go b/pkg/ext-proc/handlers/server.go index 6be747dac..047331679 100644 --- a/pkg/ext-proc/handlers/server.go +++ b/pkg/ext-proc/handlers/server.go @@ -11,17 +11,15 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/scheduling" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) -func NewServer(pp PodProvider, scheduler Scheduler, targetEndpointKey string, datastore ModelDataStore) *Server { +func NewServer(scheduler Scheduler, targetEndpointKey string, datastore backend.Datastore) *Server { return &Server{ scheduler: scheduler, - podProvider: pp, targetEndpointKey: targetEndpointKey, datastore: datastore, } @@ -30,26 +28,15 @@ func NewServer(pp PodProvider, scheduler Scheduler, targetEndpointKey string, da // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto type Server struct { - scheduler Scheduler - podProvider PodProvider + scheduler Scheduler // The key of the header to specify the target pod address. This value needs to match Envoy // configuration. targetEndpointKey string - datastore ModelDataStore + datastore backend.Datastore } type Scheduler interface { - Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backend.Pod, err error) -} - -// PodProvider is an interface to provide set of pods in the backend and information such as metrics. -type PodProvider interface { - GetPodMetrics(pod backend.Pod) (*backend.PodMetrics, bool) - UpdatePodMetrics(pod backend.Pod, pm *backend.PodMetrics) -} - -type ModelDataStore interface { - FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) + Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backend.PodMetrics, err error) } func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { @@ -140,7 +127,8 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { // RequestContext stores context information during the life time of an HTTP request. 
type RequestContext struct { - TargetPod backend.Pod + TargetPod string + TargetPodAddress string Model string ResolvedTargetModel string RequestReceivedTimestamp time.Time diff --git a/pkg/ext-proc/health.go b/pkg/ext-proc/health.go index 8b684d39f..59aec348c 100644 --- a/pkg/ext-proc/health.go +++ b/pkg/ext-proc/health.go @@ -13,11 +13,11 @@ import ( type healthServer struct { logger logr.Logger - datastore *backend.K8sDatastore + datastore backend.Datastore } func (s *healthServer) Check(ctx context.Context, in *healthPb.HealthCheckRequest) (*healthPb.HealthCheckResponse, error) { - if !s.datastore.HasSynced() { + if !s.datastore.PoolHasSynced() { s.logger.V(logutil.VERBOSE).Info("gRPC health check not serving", "service", in.Service) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_NOT_SERVING}, nil } diff --git a/pkg/ext-proc/main.go b/pkg/ext-proc/main.go index ba593d7da..8e5886739 100644 --- a/pkg/ext-proc/main.go +++ b/pkg/ext-proc/main.go @@ -59,10 +59,6 @@ var ( "poolNamespace", runserver.DefaultPoolNamespace, "Namespace of the InferencePool this Endpoint Picker is associated with.") - refreshPodsInterval = flag.Duration( - "refreshPodsInterval", - runserver.DefaultRefreshPodsInterval, - "interval to refresh pods") refreshMetricsInterval = flag.Duration( "refreshMetricsInterval", runserver.DefaultRefreshMetricsInterval, @@ -115,8 +111,6 @@ func run() error { }) setupLog.Info("Flags processed", "flags", flags) - datastore := backend.NewK8sDataStore() - // Init runtime. cfg, err := ctrl.GetConfig() if err != nil { @@ -131,17 +125,19 @@ func run() error { } // Setup runner. + datastore := backend.NewDatastore() + provider := backend.NewProvider(&vllm.PodMetricsClientImpl{}, datastore) serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, TargetEndpointKey: *targetEndpointKey, PoolName: *poolName, PoolNamespace: *poolNamespace, - RefreshPodsInterval: *refreshPodsInterval, RefreshMetricsInterval: *refreshMetricsInterval, RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, Datastore: datastore, SecureServing: *secureServing, CertPath: *certPath, + Provider: provider, } if err := serverRunner.SetupWithManager(mgr); err != nil { setupLog.Error(err, "Failed to setup ext-proc server") @@ -154,9 +150,7 @@ func run() error { } // Register ext-proc server. - if err := mgr.Add(serverRunner.AsRunnable( - ctrl.Log.WithName("ext-proc"), datastore, &vllm.PodMetricsClientImpl{}, - )); err != nil { + if err := mgr.Add(serverRunner.AsRunnable(ctrl.Log.WithName("ext-proc"))); err != nil { setupLog.Error(err, "Failed to register ext-proc server") return err } @@ -195,7 +189,7 @@ func initLogging(opts *zap.Options) { } // registerHealthServer adds the Health gRPC server as a Runnable to the given manager. 
-func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds *backend.K8sDatastore, port int) error { +func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds backend.Datastore, port int) error { srv := grpc.NewServer() healthPb.RegisterHealthServer(srv, &healthServer{ logger: logger, diff --git a/pkg/ext-proc/scheduling/filter_test.go b/pkg/ext-proc/scheduling/filter_test.go index ee1a8c331..44f203cc1 100644 --- a/pkg/ext-proc/scheduling/filter_test.go +++ b/pkg/ext-proc/scheduling/filter_test.go @@ -6,6 +6,7 @@ import ( "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) @@ -40,7 +41,7 @@ func TestFilter(t *testing.T) { // model being active, and has low KV cache. input: []*backend.PodMetrics{ { - Pod: backend.Pod{Name: "pod1"}, + NamespacedName: types.NamespacedName{Name: "pod1"}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -52,7 +53,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod2"}, + NamespacedName: types.NamespacedName{Name: "pod2"}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -64,7 +65,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod3"}, + NamespacedName: types.NamespacedName{Name: "pod3"}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, @@ -77,7 +78,7 @@ func TestFilter(t *testing.T) { }, output: []*backend.PodMetrics{ { - Pod: backend.Pod{Name: "pod2"}, + NamespacedName: types.NamespacedName{Name: "pod2"}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -101,7 +102,7 @@ func TestFilter(t *testing.T) { // pod1 will be picked because it has capacity for the sheddable request. input: []*backend.PodMetrics{ { - Pod: backend.Pod{Name: "pod1"}, + NamespacedName: types.NamespacedName{Name: "pod1"}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -113,7 +114,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod2"}, + NamespacedName: types.NamespacedName{Name: "pod2"}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -125,7 +126,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod3"}, + NamespacedName: types.NamespacedName{Name: "pod3"}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, @@ -138,7 +139,7 @@ func TestFilter(t *testing.T) { }, output: []*backend.PodMetrics{ { - Pod: backend.Pod{Name: "pod1"}, + NamespacedName: types.NamespacedName{Name: "pod1"}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -163,7 +164,7 @@ func TestFilter(t *testing.T) { // dropped. 
input: []*backend.PodMetrics{ { - Pod: backend.Pod{Name: "pod1"}, + NamespacedName: types.NamespacedName{Name: "pod1"}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.9, @@ -175,7 +176,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod2"}, + NamespacedName: types.NamespacedName{Name: "pod2"}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.85, @@ -187,7 +188,7 @@ func TestFilter(t *testing.T) { }, }, { - Pod: backend.Pod{Name: "pod3"}, + NamespacedName: types.NamespacedName{Name: "pod3"}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.85, diff --git a/pkg/ext-proc/scheduling/scheduler.go b/pkg/ext-proc/scheduling/scheduler.go index 16cf90b87..354bd39cb 100644 --- a/pkg/ext-proc/scheduling/scheduler.go +++ b/pkg/ext-proc/scheduling/scheduler.go @@ -93,34 +93,29 @@ var ( } ) -func NewScheduler(pmp PodMetricsProvider) *Scheduler { +func NewScheduler(datastore backend.Datastore) *Scheduler { return &Scheduler{ - podMetricsProvider: pmp, - filter: defaultFilter, + datastore: datastore, + filter: defaultFilter, } } type Scheduler struct { - podMetricsProvider PodMetricsProvider - filter Filter -} - -// PodMetricsProvider is an interface to provide set of pods in the backend and information such as -// metrics. -type PodMetricsProvider interface { - AllPodMetrics() []*backend.PodMetrics + datastore backend.Datastore + filter Filter } // Schedule finds the target pod based on metrics and the requested lora adapter. -func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backend.Pod, err error) { +func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backend.PodMetrics, err error) { logger := log.FromContext(ctx).WithValues("request", req) - logger.V(logutil.VERBOSE).Info("Scheduling a request", "metrics", s.podMetricsProvider.AllPodMetrics()) - pods, err := s.filter.Filter(logger, req, s.podMetricsProvider.AllPodMetrics()) + podMetrics := s.datastore.PodGetAll() + logger.V(logutil.VERBOSE).Info("Scheduling a request", "metrics", podMetrics) + pods, err := s.filter.Filter(logger, req, podMetrics) if err != nil || len(pods) == 0 { - return backend.Pod{}, fmt.Errorf( + return backend.PodMetrics{}, fmt.Errorf( "failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) } logger.V(logutil.VERBOSE).Info("Selecting a random pod from the candidates", "candidatePods", pods) i := rand.Intn(len(pods)) - return pods[i].Pod, nil + return *pods[i], nil } diff --git a/pkg/ext-proc/server/runserver.go b/pkg/ext-proc/server/runserver.go index fb9741d24..073c30df1 100644 --- a/pkg/ext-proc/server/runserver.go +++ b/pkg/ext-proc/server/runserver.go @@ -31,10 +31,10 @@ type ExtProcServerRunner struct { TargetEndpointKey string PoolName string PoolNamespace string - RefreshPodsInterval time.Duration RefreshMetricsInterval time.Duration RefreshPrometheusMetricsInterval time.Duration - Datastore *backend.K8sDatastore + Datastore backend.Datastore + Provider *backend.Provider SecureServing bool CertPath string } @@ -45,7 +45,6 @@ const ( DefaultTargetEndpointKey = "x-gateway-destination-endpoint" // default for --targetEndpointKey DefaultPoolName = "" // required but no default DefaultPoolNamespace = "default" // default for --poolNamespace - DefaultRefreshPodsInterval = 10 * time.Second // default for --refreshPodsInterval DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refreshMetricsInterval 
DefaultRefreshPrometheusMetricsInterval = 5 * time.Second // default for --refreshPrometheusMetricsInterval DefaultSecureServing = true // default for --secureServing @@ -57,7 +56,6 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { TargetEndpointKey: DefaultTargetEndpointKey, PoolName: DefaultPoolName, PoolNamespace: DefaultPoolNamespace, - RefreshPodsInterval: DefaultRefreshPodsInterval, RefreshMetricsInterval: DefaultRefreshMetricsInterval, RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval, SecureServing: DefaultSecureServing, @@ -107,15 +105,10 @@ func (r *ExtProcServerRunner) SetupWithManager(mgr ctrl.Manager) error { // AsRunnable returns a Runnable that can be used to start the ext-proc gRPC server. // The runnable implements LeaderElectionRunnable with leader election disabled. -func (r *ExtProcServerRunner) AsRunnable( - logger logr.Logger, - podDatastore *backend.K8sDatastore, - podMetricsClient backend.PodMetricsClient, -) manager.Runnable { +func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { return runnable.NoLeaderElection(manager.RunnableFunc(func(ctx context.Context) error { // Initialize backend provider - pp := backend.NewProvider(podMetricsClient, podDatastore) - if err := pp.Init(logger.WithName("provider"), r.RefreshPodsInterval, r.RefreshMetricsInterval, r.RefreshPrometheusMetricsInterval); err != nil { + if err := r.Provider.Init(ctx, r.RefreshMetricsInterval, r.RefreshPrometheusMetricsInterval); err != nil { logger.Error(err, "Failed to initialize backend provider") return err } @@ -145,7 +138,7 @@ func (r *ExtProcServerRunner) AsRunnable( } extProcPb.RegisterExternalProcessorServer( srv, - handlers.NewServer(pp, scheduling.NewScheduler(pp), r.TargetEndpointKey, r.Datastore), + handlers.NewServer(scheduling.NewScheduler(r.Datastore), r.TargetEndpointKey, r.Datastore), ) // Forward to the gRPC runnable. diff --git a/pkg/ext-proc/server/runserver_test.go b/pkg/ext-proc/server/runserver_test.go index 1badb8fd9..32af2cd80 100644 --- a/pkg/ext-proc/server/runserver_test.go +++ b/pkg/ext-proc/server/runserver_test.go @@ -11,7 +11,7 @@ import ( func TestRunnable(t *testing.T) { // Make sure AsRunnable() does not use leader election. 
- runner := server.NewDefaultExtProcServerRunner().AsRunnable(logutil.NewTestLogger(), nil, nil) + runner := server.NewDefaultExtProcServerRunner().AsRunnable(logutil.NewTestLogger()) r, ok := runner.(manager.LeaderElectionRunnable) if !ok { t.Fatal("runner is not LeaderElectionRunnable") diff --git a/pkg/ext-proc/test/benchmark/benchmark.go b/pkg/ext-proc/test/benchmark/benchmark.go index 9eca2edc6..a48f0465b 100644 --- a/pkg/ext-proc/test/benchmark/benchmark.go +++ b/pkg/ext-proc/test/benchmark/benchmark.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "os" @@ -12,6 +13,7 @@ import ( "github.com/jhump/protoreflect/desc" uberzap "go.uber.org/zap" "google.golang.org/protobuf/proto" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" @@ -48,11 +50,11 @@ func run() error { } opts.BindFlags(flag.CommandLine) flag.Parse() - logger := zap.New(zap.UseFlagOptions(&opts), zap.RawZapOpts(uberzap.AddCaller())) + ctx := log.IntoContext(context.Background(), logger) if *localServer { - test.StartExtProc(logger, port, *refreshPodsInterval, *refreshMetricsInterval, *refreshPrometheusMetricsInterval, fakePods(), fakeModels()) + test.StartExtProc(ctx, port, *refreshPodsInterval, *refreshMetricsInterval, *refreshPrometheusMetricsInterval, fakePods(), fakeModels()) time.Sleep(time.Second) // wait until server is up logger.Info("Server started") } @@ -81,7 +83,7 @@ func run() error { func generateRequestFunc(logger logr.Logger) func(mtd *desc.MethodDescriptor, callData *runner.CallData) []byte { return func(mtd *desc.MethodDescriptor, callData *runner.CallData) []byte { numModels := *numFakePods * (*numModelsPerPod) - req := test.GenerateRequest(logger, modelName(int(callData.RequestNumber)%numModels)) + req := test.GenerateRequest(logger, "hello", modelName(int(callData.RequestNumber)%numModels)) data, err := proto.Marshal(req) if err != nil { logutil.Fatal(logger, err, "Failed to marshal request", "request", req) @@ -105,9 +107,7 @@ func fakeModels() map[string]*v1alpha1.InferenceModel { func fakePods() []*backend.PodMetrics { pms := make([]*backend.PodMetrics, 0, *numFakePods) for i := 0; i < *numFakePods; i++ { - metrics := fakeMetrics(i) - pod := test.FakePod(i) - pms = append(pms, &backend.PodMetrics{Pod: pod, Metrics: metrics}) + pms = append(pms, test.FakePodMetrics(i, fakeMetrics(i))) } return pms diff --git a/pkg/ext-proc/test/utils.go b/pkg/ext-proc/test/utils.go index cb99a36be..af3e44016 100644 --- a/pkg/ext-proc/test/utils.go +++ b/pkg/ext-proc/test/utils.go @@ -1,6 +1,7 @@ package test import ( + "context" "encoding/json" "fmt" "net" @@ -10,36 +11,50 @@ import ( "github.com/go-logr/logr" "google.golang.org/grpc" "google.golang.org/grpc/reflection" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/handlers" "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/scheduling" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" + utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" ) func StartExtProc( - logger logr.Logger, + ctx context.Context, port int, refreshPodsInterval, refreshMetricsInterval, refreshPrometheusMetricsInterval 
time.Duration, pods []*backend.PodMetrics, models map[string]*v1alpha1.InferenceModel, ) *grpc.Server { - ps := make(backend.PodSet) - pms := make(map[backend.Pod]*backend.PodMetrics) + logger := log.FromContext(ctx) + pms := make(map[types.NamespacedName]*backend.PodMetrics) for _, pod := range pods { - ps[pod.Pod] = true - pms[pod.Pod] = pod + pms[pod.NamespacedName] = pod } pmc := &backend.FakePodMetricsClient{Res: pms} - pp := backend.NewProvider(pmc, backend.NewK8sDataStore(backend.WithPods(pods))) - if err := pp.Init(logger, refreshPodsInterval, refreshMetricsInterval, refreshPrometheusMetricsInterval); err != nil { + datastore := backend.NewDatastore() + for _, m := range models { + datastore.ModelSet(m) + } + for _, pm := range pods { + pod := utiltesting.MakePod(pm.NamespacedName.Name, pm.NamespacedName.Namespace). + ReadyCondition(). + IP(pm.Address). + Obj() + datastore.PodAddIfNotExist(&pod) + datastore.PodUpdateMetricsIfExist(pm) + } + pp := backend.NewProvider(pmc, datastore) + if err := pp.Init(ctx, refreshMetricsInterval, refreshPrometheusMetricsInterval); err != nil { logutil.Fatal(logger, err, "Failed to initialize") } - return startExtProc(logger, port, pp, models) + return startExtProc(logger, port, datastore) } // startExtProc starts an extProc server with fake pods. -func startExtProc(logger logr.Logger, port int, pp *backend.Provider, models map[string]*v1alpha1.InferenceModel) *grpc.Server { +func startExtProc(logger logr.Logger, port int, datastore backend.Datastore) *grpc.Server { lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { logutil.Fatal(logger, err, "Failed to listen", "port", port) @@ -47,7 +62,7 @@ func startExtProc(logger logr.Logger, port int, pp *backend.Provider, models map s := grpc.NewServer() - extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(pp, scheduling.NewScheduler(pp), "target-pod", &backend.FakeDataStore{Res: models})) + extProcPb.RegisterExternalProcessorServer(s, handlers.NewServer(scheduling.NewScheduler(datastore), "target-pod", datastore)) logger.Info("gRPC server starting", "port", port) reflection.Register(s) @@ -60,10 +75,10 @@ func startExtProc(logger logr.Logger, port int, pp *backend.Provider, models map return s } -func GenerateRequest(logger logr.Logger, model string) *extProcPb.ProcessingRequest { +func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.ProcessingRequest { j := map[string]interface{}{ "model": model, - "prompt": "hello", + "prompt": prompt, "max_tokens": 100, "temperature": 0, } @@ -80,11 +95,12 @@ func GenerateRequest(logger logr.Logger, model string) *extProcPb.ProcessingRequ return req } -func FakePod(index int) backend.Pod { +func FakePodMetrics(index int, metrics backend.Metrics) *backend.PodMetrics { address := fmt.Sprintf("address-%v", index) - pod := backend.Pod{ - Name: fmt.Sprintf("pod-%v", index), - Address: address, + pod := backend.PodMetrics{ + NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index)}, + Address: address, + Metrics: metrics, } - return pod + return &pod } diff --git a/pkg/ext-proc/util/testing/wrappers.go b/pkg/ext-proc/util/testing/wrappers.go new file mode 100644 index 000000000..f9005499d --- /dev/null +++ b/pkg/ext-proc/util/testing/wrappers.go @@ -0,0 +1,50 @@ +package testing + +import ( + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// PodWrapper wraps a Pod. +type PodWrapper struct { + corev1.Pod +} + +// MakePod creates a wrapper for a Pod. 
+func MakePod(podName, ns string) *PodWrapper {
+	return &PodWrapper{
+		corev1.Pod{
+			ObjectMeta: metav1.ObjectMeta{
+				Name:      podName,
+				Namespace: ns,
+			},
+			Spec:   corev1.PodSpec{},
+			Status: corev1.PodStatus{},
+		},
+	}
+}
+
+// Labels sets the pod labels.
+func (p *PodWrapper) Labels(labels map[string]string) *PodWrapper {
+	p.ObjectMeta.Labels = labels
+	return p
+}
+
+// ReadyCondition sets the PodReady=true condition.
+func (p *PodWrapper) ReadyCondition() *PodWrapper {
+	p.Status.Conditions = []corev1.PodCondition{{
+		Type:   corev1.PodReady,
+		Status: corev1.ConditionTrue,
+	}}
+	return p
+}
+
+func (p *PodWrapper) IP(ip string) *PodWrapper {
+	p.Status.PodIP = ip
+	return p
+}
+
+// Obj returns the wrapped Pod.
+func (p *PodWrapper) Obj() corev1.Pod {
+	return p.Pod
+}
diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go
index a99b6bd7d..917937a3b 100644
--- a/test/integration/hermetic_test.go
+++ b/test/integration/hermetic_test.go
@@ -22,6 +22,7 @@ import (
 	"google.golang.org/protobuf/testing/protocmp"
 	"google.golang.org/protobuf/types/known/structpb"
 	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/apimachinery/pkg/types"
 	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
 	k8syaml "k8s.io/apimachinery/pkg/util/yaml"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
@@ -33,6 +34,7 @@ import (
 	runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/server"
 	extprocutils "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/test"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
+	utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing"
 	"sigs.k8s.io/yaml"
 )
 
@@ -61,36 +63,27 @@ func TestKubeInferenceModelRequest(t *testing.T) {
 	}{
 		{
 			name: "select lower queue and kv cache, no active lora",
-			req:  extprocutils.GenerateRequest(logger, "my-model"),
+			req:  extprocutils.GenerateRequest(logger, "test1", "my-model"),
 			// pod-1 will be picked because it has relatively low queue size and low KV cache.
pods: []*backend.PodMetrics{ - { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.2, - }, - }, - { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - }, - }, - { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - }, - }, + extprocutils.FakePodMetrics(0, backend.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.2, + }), + extprocutils.FakePodMetrics(1, backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + }), + extprocutils.FakePodMetrics(2, backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + }), }, wantHeaders: []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-1"), + RawValue: []byte("address-1:8000"), }, }, { @@ -104,58 +97,49 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "address-1", + StringValue: "address-1:8000", }, }, }, }, - wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model-12345\",\"prompt\":\"hello\",\"temperature\":0}"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model-12345\",\"prompt\":\"test1\",\"temperature\":0}"), wantErr: false, }, { name: "select active lora, low queue", - req: extprocutils.GenerateRequest(logger, "sql-lora"), + req: extprocutils.GenerateRequest(logger, "test2", "sql-lora"), // pod-1 will be picked because it has relatively low queue size, with the requested // model being active, and has low KV cache. pods: []*backend.PodMetrics{ - { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, + extprocutils.FakePodMetrics(0, backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, }, - }, - { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, - }, + }), + extprocutils.FakePodMetrics(1, backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, }, - }, - { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - }, + }), + extprocutils.FakePodMetrics(2, backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, }, - }, + }), }, wantHeaders: []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-1"), + RawValue: []byte("address-1:8000"), }, }, { @@ -169,59 +153,50 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "address-1", + StringValue: "address-1:8000", }, }, }, }, - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"hello\",\"temperature\":0}"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test2\",\"temperature\":0}"), wantErr: false, }, { name: "select no lora despite active model, avoid excessive queue 
size", - req: extprocutils.GenerateRequest(logger, "sql-lora"), + req: extprocutils.GenerateRequest(logger, "test3", "sql-lora"), // pod-2 will be picked despite it NOT having the requested model being active // as it's above the affinity for queue size. Also is critical, so we should // still honor request despite all queues > 5 pods: []*backend.PodMetrics{ - { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, + extprocutils.FakePodMetrics(0, backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, }, - }, - { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 50, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, - }, + }), + extprocutils.FakePodMetrics(1, backend.Metrics{ + WaitingQueueSize: 50, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, }, - }, - { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 6, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - }, + }), + extprocutils.FakePodMetrics(2, backend.Metrics{ + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, }, - }, + }), }, wantHeaders: []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-2"), + RawValue: []byte("address-2:8000"), }, }, { @@ -235,54 +210,45 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "address-2", + StringValue: "address-2:8000", }, }, }, }, - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"hello\",\"temperature\":0}"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test3\",\"temperature\":0}"), wantErr: false, }, { name: "noncritical and all models past threshold, shed request", - req: extprocutils.GenerateRequest(logger, "sql-lora-sheddable"), + req: extprocutils.GenerateRequest(logger, "test4", "sql-lora-sheddable"), // no pods will be picked as all models are either above kv threshold, // queue threshold, or both. 
pods: []*backend.PodMetrics{ - { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 6, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - "sql-lora-1fdg3": 1, - }, + extprocutils.FakePodMetrics(0, backend.Metrics{ + WaitingQueueSize: 6, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, }, - }, - { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.85, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, + }), + extprocutils.FakePodMetrics(1, backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, }, - }, - { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, + }), + extprocutils.FakePodMetrics(2, backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, }, - }, + }), }, wantHeaders: []*configPb.HeaderValueOption{}, wantMetadata: &structpb.Struct{}, @@ -296,49 +262,40 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, { name: "noncritical, but one server has capacity, do not shed", - req: extprocutils.GenerateRequest(logger, "sql-lora-sheddable"), + req: extprocutils.GenerateRequest(logger, "test5", "sql-lora-sheddable"), // pod 0 will be picked as all other models are above threshold pods: []*backend.PodMetrics{ - { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 4, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - "sql-lora-1fdg3": 1, - }, + extprocutils.FakePodMetrics(0, backend.Metrics{ + WaitingQueueSize: 4, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + "sql-lora-1fdg3": 1, }, - }, - { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.85, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, + }), + extprocutils.FakePodMetrics(1, backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.85, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, }, - }, - { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, + }), + extprocutils.FakePodMetrics(2, backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg3": 1, }, - }, + }), }, wantHeaders: []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-0"), + RawValue: []byte("address-0:8000"), }, }, { @@ -352,18 +309,19 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "address-0", + StringValue: "address-0:8000", }, }, }, }, - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"hello\",\"temperature\":0}"), + wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test5\",\"temperature\":0}"), wantErr: false, }, } // Set up global k8sclient and extproc server runner with test environment config - 
BeforeSuit() + cleanup := BeforeSuit() + defer cleanup() for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -405,27 +363,30 @@ func TestKubeInferenceModelRequest(t *testing.T) { } } -func setUpHermeticServer(pods []*backend.PodMetrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { - ps := make(backend.PodSet) - pms := make(map[backend.Pod]*backend.PodMetrics) - for _, pod := range pods { - ps[pod.Pod] = true - pms[pod.Pod] = pod +func setUpHermeticServer(podMetrics []*backend.PodMetrics) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { + pms := make(map[types.NamespacedName]*backend.PodMetrics) + for _, pm := range podMetrics { + pms[pm.NamespacedName] = pm } pmc := &backend.FakePodMetricsClient{Res: pms} serverCtx, stopServer := context.WithCancel(context.Background()) go func() { - if err := serverRunner.AsRunnable( - logger.WithName("ext-proc"), backend.NewK8sDataStore(backend.WithPods(pods)), pmc, - ).Start(serverCtx); err != nil { + serverRunner.Datastore.PodDeleteAll() + for _, pm := range podMetrics { + pod := utiltesting.MakePod(pm.NamespacedName.Name, pm.NamespacedName.Namespace). + ReadyCondition(). + IP(pm.Address). + Obj() + serverRunner.Datastore.PodAddIfNotExist(&pod) + serverRunner.Datastore.PodUpdateMetricsIfExist(pm) + } + serverRunner.Provider = backend.NewProvider(pmc, serverRunner.Datastore) + if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil { logutil.Fatal(logger, err, "Failed to start ext-proc server") } }() - // Wait the reconciler to populate the datastore. - time.Sleep(10 * time.Second) - address := fmt.Sprintf("localhost:%v", port) // Create a grpc connection conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -442,11 +403,13 @@ func setUpHermeticServer(pods []*backend.PodMetrics) (client extProcPb.ExternalP cancel() conn.Close() stopServer() + // wait a little until the goroutines actually exit + time.Sleep(5 * time.Second) } } // Sets up a test environment and returns the runner struct -func BeforeSuit() { +func BeforeSuit() func() { // Set up mock k8s API Client testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, @@ -477,7 +440,7 @@ func BeforeSuit() { serverRunner = runserver.NewDefaultExtProcServerRunner() // Adjust from defaults serverRunner.PoolName = "vllm-llama2-7b-pool" - serverRunner.Datastore = backend.NewK8sDataStore() + serverRunner.Datastore = backend.NewDatastore() serverRunner.SecureServing = false if err := serverRunner.SetupWithManager(mgr); err != nil { @@ -524,6 +487,25 @@ func BeforeSuit() { } } } + + if !blockUntilPoolSyncs(serverRunner.Datastore) { + logutil.Fatal(logger, nil, "Timeout waiting for the pool and models to sync") + } + + return func() { + _ = testEnv.Stop() + } +} + +func blockUntilPoolSyncs(datastore backend.Datastore) bool { + // We really need to move those tests to gingo so we can use Eventually... 
+ for i := 1; i < 10; i++ { + if datastore.PoolHasSynced() && datastore.ModelGet("my-model") != nil { + return true + } + time.Sleep(1 * time.Second) + } + return false } func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { From 89b425b6fc5ebcdb0a52275a465f6923c46e11be Mon Sep 17 00:00:00 2001 From: ahg-g Date: Mon, 17 Feb 2025 20:46:00 +0000 Subject: [PATCH 2/8] Fixed the provider test and covered the pool deletion events. --- pkg/ext-proc/backend/datastore.go | 17 +++- .../backend/inferencepool_reconciler.go | 16 +++- .../backend/inferencepool_reconciler_test.go | 39 ++++++++- pkg/ext-proc/backend/provider_test.go | 86 ++++++++++++------- test/integration/hermetic_test.go | 23 ++--- 5 files changed, 127 insertions(+), 54 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index 05e0a9c38..d84142aef 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -35,10 +35,13 @@ type Datastore interface { PodUpdateMetricsIfExist(pm *PodMetrics) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) PodDelete(namespacedName types.NamespacedName) - PodFlush(ctx context.Context, ctrlClient client.Client) + PodFlushAll(ctx context.Context, ctrlClient client.Client) PodGetAll() []*PodMetrics - PodRange(f func(key, value any) bool) PodDeleteAll() // This is only for testing. + PodRange(f func(key, value any) bool) + + // Clears the store state, happens when the pool gets deleted. + Clear() } func NewDatastore() Datastore { @@ -59,6 +62,14 @@ type datastore struct { pods *sync.Map } +func (ds *datastore) Clear() { + ds.poolMu.Lock() + defer ds.poolMu.Unlock() + ds.pool = nil + ds.models.Clear() + ds.pods.Clear() +} + // /// InferencePool APIs /// func (ds *datastore) PoolSet(pool *v1alpha1.InferencePool) { ds.poolMu.Lock() @@ -158,7 +169,7 @@ func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool { return false } -func (ds *datastore) PodFlush(ctx context.Context, ctrlClient client.Client) { +func (ds *datastore) PodFlushAll(ctx context.Context, ctrlClient client.Client) { // Pool must exist to invoke this function. pool, _ := ds.PoolGet() podList := &corev1.PodList{} diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 36a7a60c2..6ac805690 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -4,6 +4,7 @@ import ( "context" "reflect" + "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -36,12 +37,19 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques loggerDefault.Info("Reconciling InferencePool", "name", req.NamespacedName) serverPool := &v1alpha1.InferencePool{} + if err := c.Get(ctx, req.NamespacedName, serverPool); err != nil { + if errors.IsNotFound(err) { + loggerDefault.Info("InferencePool not found. Clearing the datastore", "name", req.NamespacedName) + c.Datastore.Clear() + return ctrl.Result{}, nil + } loggerDefault.Error(err, "Unable to get InferencePool", "name", req.NamespacedName) return ctrl.Result{}, err - - // TODO: Handle InferencePool deletions. Need to flush the datastore. - // TODO: Handle port updates, podMetrics should not be storing that as part of the address. 
+	} else if !serverPool.DeletionTimestamp.IsZero() {
+		loggerDefault.Info("InferencePool is marked for deletion. Clearing the datastore", "name", req.NamespacedName)
+		c.Datastore.Clear()
+		return ctrl.Result{}, nil
 	}
 
 	c.updateDatastore(ctx, serverPool)
@@ -55,7 +63,7 @@ func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *
 	c.Datastore.PoolSet(newPool)
 	if oldPool == nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) {
 		logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "target", klog.KMetadata(&newPool.ObjectMeta))
-		c.Datastore.PodFlush(ctx, c.Client)
+		c.Datastore.PodFlushAll(ctx, c.Client)
 	}
 }
 
diff --git a/pkg/ext-proc/backend/inferencepool_reconciler_test.go b/pkg/ext-proc/backend/inferencepool_reconciler_test.go
index c1c700a0c..b6403489b 100644
--- a/pkg/ext-proc/backend/inferencepool_reconciler_test.go
+++ b/pkg/ext-proc/backend/inferencepool_reconciler_test.go
@@ -26,7 +26,10 @@ var (
 			Name:      "pool1",
 			Namespace: "pool1-ns",
 		},
-		Spec: v1alpha1.InferencePoolSpec{Selector: selector_v1},
+		Spec: v1alpha1.InferencePoolSpec{
+			Selector:         selector_v1,
+			TargetPortNumber: 8080,
+		},
 	}
 	pool2 = &v1alpha1.InferencePool{
 		ObjectMeta: metav1.ObjectMeta{
@@ -48,6 +51,9 @@ var (
 )
 
 func TestReconcile_InferencePoolReconciler(t *testing.T) {
+	// The best practice is to use table-driven tests; however, in this scenario it seems
+	// more logical to do a single test with steps that depend on each other.
+
 	// Set up the scheme.
 	scheme := runtime.NewScheme()
 	_ = clientgoscheme.AddToScheme(scheme)
@@ -63,7 +69,7 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) {
 		WithObjects(initialObjects...).
 		Build()
 
-	// Create a request for the existing resource.
+	// Create a request for the existing resource.
namespacedName := types.NamespacedName{Name: pool1.Name, Namespace: pool1.Namespace} req := ctrl.Request{NamespacedName: namespacedName} ctx := context.Background() @@ -103,6 +109,35 @@ func TestReconcile_InferencePoolReconciler(t *testing.T) { if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" { t.Errorf("Unexpected diff (+got/-want): %s", diff) } + + // Step 4: update the pool port + if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { + t.Errorf("Unexpected pool get error: %v", err) + } + newPool1.Spec.TargetPortNumber = 9090 + if err := fakeClient.Update(ctx, newPool1, &client.UpdateOptions{}); err != nil { + t.Errorf("Unexpected pool update error: %v", err) + } + if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + if diff := diffPool(datastore, newPool1, []string{"pod5"}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } + + // Step 5: delete the pool to trigger a datastore clear + if err := fakeClient.Get(ctx, req.NamespacedName, newPool1); err != nil { + t.Errorf("Unexpected pool get error: %v", err) + } + if err := fakeClient.Delete(ctx, newPool1, &client.DeleteOptions{}); err != nil { + t.Errorf("Unexpected pool delete error: %v", err) + } + if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { + t.Errorf("Unexpected InferencePool reconcile error: %v", err) + } + if diff := diffPool(datastore, nil, []string{}); diff != "" { + t.Errorf("Unexpected diff (+got/-want): %s", diff) + } } func diffPool(datastore Datastore, wantPool *v1alpha1.InferencePool, wantPods []string) string { diff --git a/pkg/ext-proc/backend/provider_test.go b/pkg/ext-proc/backend/provider_test.go index d7d047802..e313d703b 100644 --- a/pkg/ext-proc/backend/provider_test.go +++ b/pkg/ext-proc/backend/provider_test.go @@ -1,13 +1,16 @@ package backend import ( + "context" + "errors" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/stretchr/testify/assert" "k8s.io/apimachinery/pkg/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" ) var ( @@ -42,8 +45,6 @@ var ( ) func TestProvider(t *testing.T) { - logger := logutil.NewTestLogger() - tests := []struct { name string pmc PodMetricsClient @@ -51,11 +52,8 @@ func TestProvider(t *testing.T) { want []*PodMetrics }{ { - name: "Fetch metrics error", + name: "Probing metrics success", pmc: &FakePodMetricsClient{ - // Err: map[string]error{ - // pod2.Name: errors.New("injected error"), - // }, Res: map[types.NamespacedName]*PodMetrics{ pod1.NamespacedName: pod1, pod2.NamespacedName: pod2, @@ -67,16 +65,47 @@ func TestProvider(t *testing.T) { want: []*PodMetrics{ pod1, pod2, - // // Failed to fetch pod2 metrics so it remains the default values. 
- // { - // Name: "pod2", - // Metrics: Metrics{ - // WaitingQueueSize: 0, - // KVCacheUsagePercent: 0, - // MaxActiveModels: 0, - // ActiveModels: map[string]int{}, - // }, - // }, + }, + }, + { + name: "Only pods in the datastore are probed", + pmc: &FakePodMetricsClient{ + Res: map[types.NamespacedName]*PodMetrics{ + pod1.NamespacedName: pod1, + pod2.NamespacedName: pod2, + }, + }, + datastore: &datastore{ + pods: populateMap(pod1), + }, + want: []*PodMetrics{ + pod1, + }, + }, + { + name: "Probing metrics error", + pmc: &FakePodMetricsClient{ + Err: map[types.NamespacedName]error{ + pod2.NamespacedName: errors.New("injected error"), + }, + Res: map[types.NamespacedName]*PodMetrics{ + pod1.NamespacedName: pod1, + }, + }, + datastore: &datastore{ + pods: populateMap(pod1, pod2), + }, + want: []*PodMetrics{ + pod1, + // Failed to fetch pod2 metrics so it remains the default values. + { + NamespacedName: pod2.NamespacedName, + Metrics: Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0, + MaxActiveModels: 0, + }, + }, }, }, } @@ -84,17 +113,16 @@ func TestProvider(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { p := NewProvider(test.pmc, test.datastore) - // if err := p.refreshMetricsOnce(logger); err != nil { - // t.Fatalf("Unexpected error: %v", err) - // } - _ = p.refreshMetricsOnce(logger) - metrics := test.datastore.PodGetAll() - lessFunc := func(a, b *PodMetrics) bool { - return a.String() < b.String() - } - if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc)); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _ = p.Init(ctx, time.Millisecond, time.Millisecond) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + metrics := test.datastore.PodGetAll() + diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(func(a, b *PodMetrics) bool { + return a.String() < b.String() + })) + assert.Equal(t, "", diff, "Unexpected diff (+got/-want)") + }, 5*time.Second, time.Millisecond) }) } } @@ -102,7 +130,7 @@ func TestProvider(t *testing.T) { func populateMap(pods ...*PodMetrics) *sync.Map { newMap := &sync.Map{} for _, pod := range pods { - newMap.Store(pod.NamespacedName, pod) + newMap.Store(pod.NamespacedName, &PodMetrics{NamespacedName: pod.NamespacedName}) } return newMap } diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 917937a3b..a87398922 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -17,6 +17,7 @@ import ( extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" @@ -320,7 +321,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { } // Set up global k8sclient and extproc server runner with test environment config - cleanup := BeforeSuit() + cleanup := BeforeSuit(t) defer cleanup() for _, test := range tests { @@ -409,7 +410,7 @@ func setUpHermeticServer(podMetrics []*backend.PodMetrics) (client extProcPb.Ext } // Sets up a test environment and returns the runner struct -func BeforeSuit() func() { +func BeforeSuit(t *testing.T) func() { // Set up mock k8s API Client testEnv = &envtest.Environment{ CRDDirectoryPaths: 
[]string{filepath.Join("..", "..", "config", "crd", "bases")}, @@ -488,26 +489,16 @@ func BeforeSuit() func() { } } - if !blockUntilPoolSyncs(serverRunner.Datastore) { - logutil.Fatal(logger, nil, "Timeout waiting for the pool and models to sync") - } + assert.EventuallyWithT(t, func(t *assert.CollectT) { + synced := serverRunner.Datastore.PoolHasSynced() && serverRunner.Datastore.ModelGet("my-model") != nil + assert.True(t, synced, "Timeout waiting for the pool and models to sync") + }, 10*time.Second, 10*time.Millisecond) return func() { _ = testEnv.Stop() } } -func blockUntilPoolSyncs(datastore backend.Datastore) bool { - // We really need to move those tests to gingo so we can use Eventually... - for i := 1; i < 10; i++ { - if datastore.PoolHasSynced() && datastore.ModelGet("my-model") != nil { - return true - } - time.Sleep(1 * time.Second) - } - return false -} - func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { t.Logf("Sending request: %v", req) if err := client.Send(req); err != nil { From f62bf4f893af0ae8bd6516ad36309382398f24dc Mon Sep 17 00:00:00 2001 From: ahg-g Date: Mon, 17 Feb 2025 22:09:02 +0000 Subject: [PATCH 3/8] Don't store the port number with the pods --- pkg/ext-proc/backend/datastore.go | 5 +---- pkg/ext-proc/handlers/request.go | 16 ++++++++++++---- pkg/ext-proc/handlers/server.go | 2 +- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index d84142aef..d078541e8 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -4,7 +4,6 @@ import ( "context" "errors" "math/rand" - "strconv" "sync" "github.com/go-logr/logr" @@ -150,14 +149,12 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { } func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool { - // new pod, add to the store for probing - pool, _ := ds.PoolGet() new := &PodMetrics{ NamespacedName: types.NamespacedName{ Name: pod.Name, Namespace: pod.Namespace, }, - Address: pod.Status.PodIP + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)), + Address: pod.Status.PodIP, Metrics: Metrics{ ActiveModels: make(map[string]int), }, diff --git a/pkg/ext-proc/handlers/request.go b/pkg/ext-proc/handlers/request.go index def75c2fc..4910d0c78 100644 --- a/pkg/ext-proc/handlers/request.go +++ b/pkg/ext-proc/handlers/request.go @@ -82,21 +82,29 @@ func (s *Server) HandleRequestBody( if err != nil { return nil, fmt.Errorf("failed to find target pod: %w", err) } + logger.V(logutil.DEFAULT).Info("Request handled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) + // Insert target endpoint to instruct Envoy to route requests to the specified target pod. + // Attach the port number + pool, err := s.datastore.PoolGet() + if err != nil { + return nil, err + } + endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) + reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel reqCtx.RequestSize = len(v.RequestBody.Body) reqCtx.TargetPod = targetPod.NamespacedName.String() - reqCtx.TargetPodAddress = targetPod.Address + reqCtx.TargetEndpoint = endpoint - // Insert target endpoint to instruct Envoy to route requests to the specified target pod. 
headers := []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ Key: s.targetEndpointKey, - RawValue: []byte(targetPod.Address), + RawValue: []byte(endpoint), }, }, // We need to update the content length header if the body is mutated, see Envoy doc: @@ -135,7 +143,7 @@ func (s *Server) HandleRequestBody( Fields: map[string]*structpb.Value{ s.targetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: targetPod.Address, + StringValue: endpoint, }, }, }, diff --git a/pkg/ext-proc/handlers/server.go b/pkg/ext-proc/handlers/server.go index 047331679..fe00ebeb3 100644 --- a/pkg/ext-proc/handlers/server.go +++ b/pkg/ext-proc/handlers/server.go @@ -128,7 +128,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { // RequestContext stores context information during the life time of an HTTP request. type RequestContext struct { TargetPod string - TargetPodAddress string + TargetEndpoint string Model string ResolvedTargetModel string RequestReceivedTimestamp time.Time From 0853b36377831ce14e05cb736b0e2c03c3066344 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Mon, 17 Feb 2025 22:55:41 +0000 Subject: [PATCH 4/8] Address pod ip address updates --- pkg/ext-proc/backend/datastore.go | 22 +++++--- pkg/ext-proc/backend/pod_reconciler.go | 2 +- pkg/ext-proc/backend/pod_reconciler_test.go | 62 ++++++++++++++++----- pkg/ext-proc/backend/provider_test.go | 16 ++++-- pkg/ext-proc/backend/types.go | 14 +++-- pkg/ext-proc/scheduling/filter_test.go | 22 ++++---- pkg/ext-proc/test/utils.go | 10 ++-- test/integration/hermetic_test.go | 2 +- 8 files changed, 101 insertions(+), 49 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index d078541e8..75106a7b5 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -30,7 +30,7 @@ type Datastore interface { ModelDelete(modelName string) // PodMetrics operations - PodAddIfNotExist(pod *corev1.Pod) bool + PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool PodUpdateMetricsIfExist(pm *PodMetrics) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) PodDelete(namespacedName types.NamespacedName) @@ -148,21 +148,27 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { ds.pods.Delete(namespacedName) } -func (ds *datastore) PodAddIfNotExist(pod *corev1.Pod) bool { +func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { new := &PodMetrics{ - NamespacedName: types.NamespacedName{ - Name: pod.Name, - Namespace: pod.Namespace, + Pod: Pod{ + NamespacedName: types.NamespacedName{ + Name: pod.Name, + Namespace: pod.Namespace, + }, + Address: pod.Status.PodIP, }, - Address: pod.Status.PodIP, Metrics: Metrics{ ActiveModels: make(map[string]int), }, } - if _, ok := ds.pods.Load(new.NamespacedName); !ok { + existing, ok := ds.pods.Load(new.NamespacedName) + if !ok { ds.pods.Store(new.NamespacedName, new) return true } + + // Update pod properties if anything changed. 
+ existing.(*PodMetrics).Pod = new.Pod return false } @@ -182,7 +188,7 @@ func (ds *datastore) PodFlushAll(ctx context.Context, ctrlClient client.Client) for _, pod := range podList.Items { if podIsReady(&pod) { activePods[pod.Name] = true - ds.PodAddIfNotExist(&pod) + ds.PodUpdateOrAddIfNotExist(&pod) } } diff --git a/pkg/ext-proc/backend/pod_reconciler.go b/pkg/ext-proc/backend/pod_reconciler.go index 9bfe3dc89..8705ce838 100644 --- a/pkg/ext-proc/backend/pod_reconciler.go +++ b/pkg/ext-proc/backend/pod_reconciler.go @@ -61,7 +61,7 @@ func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { logger.V(logutil.DEFAULT).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { - if c.Datastore.PodAddIfNotExist(pod) { + if c.Datastore.PodUpdateOrAddIfNotExist(pod) { logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) } else { logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) diff --git a/pkg/ext-proc/backend/pod_reconciler_test.go b/pkg/ext-proc/backend/pod_reconciler_test.go index c2522fbba..cc7381f66 100644 --- a/pkg/ext-proc/backend/pod_reconciler_test.go +++ b/pkg/ext-proc/backend/pod_reconciler_test.go @@ -18,9 +18,10 @@ import ( ) var ( - basePod1 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: ":8000"} - basePod2 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod2"}, Address: ":8000"} - basePod3 = &PodMetrics{NamespacedName: types.NamespacedName{Name: "pod3"}, Address: ":8000"} + basePod1 = &PodMetrics{Pod: Pod{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: "address-1"}} + basePod2 = &PodMetrics{Pod: Pod{NamespacedName: types.NamespacedName{Name: "pod2"}, Address: "address-2"}} + basePod3 = &PodMetrics{Pod: Pod{NamespacedName: types.NamespacedName{Name: "pod3"}, Address: "address-3"}} + basePod11 = &PodMetrics{Pod: Pod{NamespacedName: types.NamespacedName{Name: "pod1"}, Address: "address-11"}} ) func TestUpdateDatastore_PodReconciler(t *testing.T) { @@ -29,7 +30,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { name string datastore Datastore incomingPod *corev1.Pod - wantPods []types.NamespacedName + wantPods []Pod req *ctrl.Request }{ { @@ -47,12 +48,45 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, incomingPod: &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ - Name: "pod3", + Name: basePod3.NamespacedName.Name, + Labels: map[string]string{ + "some-key": "some-val", + }, + }, + Status: corev1.PodStatus{ + PodIP: basePod3.Address, + Conditions: []corev1.PodCondition{ + { + Type: corev1.PodReady, + Status: corev1.ConditionTrue, + }, + }, + }, + }, + wantPods: []Pod{basePod1.Pod, basePod2.Pod, basePod3.Pod}, + }, + { + name: "Update pod1 address", + datastore: &datastore{ + pods: populateMap(basePod1, basePod2), + pool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + Selector: map[v1alpha1.LabelKey]v1alpha1.LabelValue{ + "some-key": "some-val", + }, + }, + }, + }, + incomingPod: &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: basePod11.NamespacedName.Name, Labels: map[string]string{ "some-key": "some-val", }, }, Status: corev1.PodStatus{ + PodIP: basePod11.Address, Conditions: []corev1.PodCondition{ { Type: corev1.PodReady, @@ -61,7 +95,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []types.NamespacedName{basePod1.NamespacedName, basePod2.NamespacedName, basePod3.NamespacedName}, + wantPods: []Pod{basePod11.Pod, 
basePod2.Pod}, }, { name: "Delete pod with DeletionTimestamp", @@ -94,7 +128,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []types.NamespacedName{basePod2.NamespacedName}, + wantPods: []Pod{basePod2.Pod}, }, { name: "Delete notfound pod", @@ -110,7 +144,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, req: &ctrl.Request{NamespacedName: types.NamespacedName{Name: "pod1"}}, - wantPods: []types.NamespacedName{basePod2.NamespacedName}, + wantPods: []Pod{basePod2.Pod}, }, { name: "New pod, not ready, valid selector", @@ -141,7 +175,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []types.NamespacedName{basePod1.NamespacedName, basePod2.NamespacedName}, + wantPods: []Pod{basePod1.Pod, basePod2.Pod}, }, { name: "Remove pod that does not match selector", @@ -172,7 +206,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []types.NamespacedName{basePod2.NamespacedName}, + wantPods: []Pod{basePod2.Pod}, }, { name: "Remove pod that is not ready", @@ -203,7 +237,7 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { }, }, }, - wantPods: []types.NamespacedName{basePod2.NamespacedName}, + wantPods: []Pod{basePod2.Pod}, }, } for _, test := range tests { @@ -229,15 +263,15 @@ func TestUpdateDatastore_PodReconciler(t *testing.T) { t.Errorf("Unexpected InferencePool reconcile error: %v", err) } - var gotPods []types.NamespacedName + var gotPods []Pod test.datastore.PodRange(func(k, v any) bool { pod := v.(*PodMetrics) if v != nil { - gotPods = append(gotPods, pod.NamespacedName) + gotPods = append(gotPods, pod.Pod) } return true }) - if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b types.NamespacedName) bool { return a.String() < b.String() })) { + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b Pod) bool { return a.NamespacedName.String() < b.NamespacedName.String() })) { t.Errorf("got (%v) != want (%v);", gotPods, test.wantPods) } }) diff --git a/pkg/ext-proc/backend/provider_test.go b/pkg/ext-proc/backend/provider_test.go index e313d703b..2aa2c2139 100644 --- a/pkg/ext-proc/backend/provider_test.go +++ b/pkg/ext-proc/backend/provider_test.go @@ -15,8 +15,10 @@ import ( var ( pod1 = &PodMetrics{ - NamespacedName: types.NamespacedName{ - Name: "pod1", + Pod: Pod{ + NamespacedName: types.NamespacedName{ + Name: "pod1", + }, }, Metrics: Metrics{ WaitingQueueSize: 0, @@ -29,8 +31,10 @@ var ( }, } pod2 = &PodMetrics{ - NamespacedName: types.NamespacedName{ - Name: "pod2", + Pod: Pod{ + NamespacedName: types.NamespacedName{ + Name: "pod2", + }, }, Metrics: Metrics{ WaitingQueueSize: 1, @@ -99,7 +103,7 @@ func TestProvider(t *testing.T) { pod1, // Failed to fetch pod2 metrics so it remains the default values. 
{ - NamespacedName: pod2.NamespacedName, + Pod: Pod{NamespacedName: pod2.NamespacedName}, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -130,7 +134,7 @@ func TestProvider(t *testing.T) { func populateMap(pods ...*PodMetrics) *sync.Map { newMap := &sync.Map{} for _, pod := range pods { - newMap.Store(pod.NamespacedName, &PodMetrics{NamespacedName: pod.NamespacedName}) + newMap.Store(pod.NamespacedName, &PodMetrics{Pod: Pod{NamespacedName: pod.NamespacedName, Address: pod.Address}}) } return newMap } diff --git a/pkg/ext-proc/backend/types.go b/pkg/ext-proc/backend/types.go index 053c80d28..0e02fb093 100644 --- a/pkg/ext-proc/backend/types.go +++ b/pkg/ext-proc/backend/types.go @@ -7,6 +7,11 @@ import ( "k8s.io/apimachinery/pkg/types" ) +type Pod struct { + NamespacedName types.NamespacedName + Address string +} + type Metrics struct { // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. ActiveModels map[string]int @@ -19,8 +24,7 @@ type Metrics struct { } type PodMetrics struct { - NamespacedName types.NamespacedName - Address string + Pod Metrics } @@ -34,8 +38,10 @@ func (pm *PodMetrics) Clone() *PodMetrics { cm[k] = v } clone := &PodMetrics{ - NamespacedName: pm.NamespacedName, - Address: pm.Address, + Pod: Pod{ + NamespacedName: pm.NamespacedName, + Address: pm.Address, + }, Metrics: Metrics{ ActiveModels: cm, RunningQueueSize: pm.RunningQueueSize, diff --git a/pkg/ext-proc/scheduling/filter_test.go b/pkg/ext-proc/scheduling/filter_test.go index 44f203cc1..9ed781c42 100644 --- a/pkg/ext-proc/scheduling/filter_test.go +++ b/pkg/ext-proc/scheduling/filter_test.go @@ -41,7 +41,7 @@ func TestFilter(t *testing.T) { // model being active, and has low KV cache. input: []*backend.PodMetrics{ { - NamespacedName: types.NamespacedName{Name: "pod1"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -53,7 +53,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod2"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -65,7 +65,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod3"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, @@ -78,7 +78,7 @@ func TestFilter(t *testing.T) { }, output: []*backend.PodMetrics{ { - NamespacedName: types.NamespacedName{Name: "pod2"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -102,7 +102,7 @@ func TestFilter(t *testing.T) { // pod1 will be picked because it has capacity for the sheddable request. 
input: []*backend.PodMetrics{ { - NamespacedName: types.NamespacedName{Name: "pod1"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -114,7 +114,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod2"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.1, @@ -126,7 +126,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod3"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, @@ -139,7 +139,7 @@ func TestFilter(t *testing.T) { }, output: []*backend.PodMetrics{ { - NamespacedName: types.NamespacedName{Name: "pod1"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, Metrics: backend.Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -164,7 +164,7 @@ func TestFilter(t *testing.T) { // dropped. input: []*backend.PodMetrics{ { - NamespacedName: types.NamespacedName{Name: "pod1"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.9, @@ -176,7 +176,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod2"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, Metrics: backend.Metrics{ WaitingQueueSize: 3, KVCacheUsagePercent: 0.85, @@ -188,7 +188,7 @@ func TestFilter(t *testing.T) { }, }, { - NamespacedName: types.NamespacedName{Name: "pod3"}, + Pod: backend.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, Metrics: backend.Metrics{ WaitingQueueSize: 10, KVCacheUsagePercent: 0.85, diff --git a/pkg/ext-proc/test/utils.go b/pkg/ext-proc/test/utils.go index af3e44016..ff7cd89a7 100644 --- a/pkg/ext-proc/test/utils.go +++ b/pkg/ext-proc/test/utils.go @@ -43,7 +43,7 @@ func StartExtProc( ReadyCondition(). IP(pm.Address). Obj() - datastore.PodAddIfNotExist(&pod) + datastore.PodUpdateOrAddIfNotExist(&pod) datastore.PodUpdateMetricsIfExist(pm) } pp := backend.NewProvider(pmc, datastore) @@ -98,9 +98,11 @@ func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.Proces func FakePodMetrics(index int, metrics backend.Metrics) *backend.PodMetrics { address := fmt.Sprintf("address-%v", index) pod := backend.PodMetrics{ - NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index)}, - Address: address, - Metrics: metrics, + Pod: backend.Pod{ + NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index)}, + Address: address, + }, + Metrics: metrics, } return &pod } diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index a87398922..67377a69f 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -379,7 +379,7 @@ func setUpHermeticServer(podMetrics []*backend.PodMetrics) (client extProcPb.Ext ReadyCondition(). IP(pm.Address). 
Obj() - serverRunner.Datastore.PodAddIfNotExist(&pod) + serverRunner.Datastore.PodUpdateOrAddIfNotExist(&pod) serverRunner.Datastore.PodUpdateMetricsIfExist(pm) } serverRunner.Provider = backend.NewProvider(pmc, serverRunner.Datastore) From 475c85d5a5a2d2b4217b2c5d22e907e4e4c7ec60 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Tue, 18 Feb 2025 16:44:41 +0000 Subject: [PATCH 5/8] rename PodFlushAll to PodResyncAll --- pkg/ext-proc/backend/datastore.go | 4 ++-- pkg/ext-proc/backend/inferencepool_reconciler.go | 2 +- pkg/manifests/ext_proc.yaml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index 75106a7b5..09a630c15 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -34,7 +34,7 @@ type Datastore interface { PodUpdateMetricsIfExist(pm *PodMetrics) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) PodDelete(namespacedName types.NamespacedName) - PodFlushAll(ctx context.Context, ctrlClient client.Client) + PodResyncAll(ctx context.Context, ctrlClient client.Client) PodGetAll() []*PodMetrics PodDeleteAll() // This is only for testing. PodRange(f func(key, value any) bool) @@ -172,7 +172,7 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { return false } -func (ds *datastore) PodFlushAll(ctx context.Context, ctrlClient client.Client) { +func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client) { // Pool must exist to invoke this function. pool, _ := ds.PoolGet() podList := &corev1.PodList{} diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 6ac805690..83f9ee9e0 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -63,7 +63,7 @@ func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool * c.Datastore.PoolSet(newPool) if oldPool == nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "target", klog.KMetadata(&newPool.ObjectMeta)) - c.Datastore.PodFlushAll(ctx, c.Client) + c.Datastore.PodResyncAll(ctx, c.Client) } } diff --git a/pkg/manifests/ext_proc.yaml b/pkg/manifests/ext_proc.yaml index 49145d24c..a3c66bc9e 100644 --- a/pkg/manifests/ext_proc.yaml +++ b/pkg/manifests/ext_proc.yaml @@ -71,7 +71,7 @@ spec: spec: containers: - name: inference-gateway-ext-proc - image: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/epp:main + image: us-central1-docker.pkg.dev/ahg-gke-dev/jobset2/epp:dfee85a imagePullPolicy: Always args: - -poolName From 6159313b99d49046efa003ffcb634fb7e566b040 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Tue, 18 Feb 2025 18:56:05 +0000 Subject: [PATCH 6/8] Addressed first round of comments --- pkg/ext-proc/backend/datastore.go | 14 ++++++++------ .../backend/inferencemodel_reconciler_test.go | 4 ++-- pkg/ext-proc/handlers/request.go | 4 ++-- pkg/manifests/ext_proc.yaml | 2 +- pkg/manifests/vllm/deployment.yaml | 2 +- test/integration/hermetic_test.go | 3 ++- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index 09a630c15..cc4bf6a16 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -26,12 +26,12 @@ type Datastore interface { // InferenceModel operations ModelSet(infModel *v1alpha1.InferenceModel) - 
ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel) + ModelGet(modelName string) (*v1alpha1.InferenceModel, bool) ModelDelete(modelName string) // PodMetrics operations PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool - PodUpdateMetricsIfExist(pm *PodMetrics) + PodUpdateMetricsIfExist(pm *PodMetrics) bool PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) PodDelete(namespacedName types.NamespacedName) PodResyncAll(ctx context.Context, ctrlClient client.Client) @@ -102,12 +102,12 @@ func (ds *datastore) ModelSet(infModel *v1alpha1.InferenceModel) { ds.models.Store(infModel.Spec.ModelName, infModel) } -func (ds *datastore) ModelGet(modelName string) (returnModel *v1alpha1.InferenceModel) { +func (ds *datastore) ModelGet(modelName string) (*v1alpha1.InferenceModel, bool) { infModel, ok := ds.models.Load(modelName) if ok { - returnModel = infModel.(*v1alpha1.InferenceModel) + return infModel.(*v1alpha1.InferenceModel), true } - return + return nil, false } func (ds *datastore) ModelDelete(modelName string) { @@ -115,11 +115,13 @@ func (ds *datastore) ModelDelete(modelName string) { } // /// Pods/endpoints APIs /// -func (ds *datastore) PodUpdateMetricsIfExist(pm *PodMetrics) { +func (ds *datastore) PodUpdateMetricsIfExist(pm *PodMetrics) bool { if val, ok := ds.pods.Load(pm.NamespacedName); ok { existing := val.(*PodMetrics) existing.Metrics = pm.Metrics + return true } + return false } func (ds *datastore) PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) { diff --git a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go index 67872636e..5afe3b5ab 100644 --- a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go +++ b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go @@ -242,7 +242,7 @@ func TestReconcile_ModelMarkedForDeletion(t *testing.T) { } // Verify that the datastore was not updated. - if infModel := datastore.ModelGet(existingModel.Spec.ModelName); infModel != nil { + if _, exist := datastore.ModelGet(existingModel.Spec.ModelName); exist { t.Errorf("expected datastore to not contain model %q", existingModel.Spec.ModelName) } } @@ -299,7 +299,7 @@ func TestReconcile_ResourceExists(t *testing.T) { } // Verify that the datastore was updated. - if infModel := datastore.ModelGet(existingModel.Spec.ModelName); infModel == nil { + if _, exist := datastore.ModelGet(existingModel.Spec.ModelName); !exist { t.Errorf("expected datastore to contain model %q", existingModel.Spec.ModelName) } } diff --git a/pkg/ext-proc/handlers/request.go b/pkg/ext-proc/handlers/request.go index 4910d0c78..5edb2e777 100644 --- a/pkg/ext-proc/handlers/request.go +++ b/pkg/ext-proc/handlers/request.go @@ -48,8 +48,8 @@ func (s *Server) HandleRequestBody( // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. // This might be a security risk in the future where adapters not registered in the InferenceModel // are able to be requested by using their distinct name. 
- modelObj := s.datastore.ModelGet(model) - if modelObj == nil { + modelObj, exist := s.datastore.ModelGet(model) + if !exist { return nil, fmt.Errorf("error finding a model object in InferenceModel for input %v", model) } if len(modelObj.Spec.TargetModels) > 0 { diff --git a/pkg/manifests/ext_proc.yaml b/pkg/manifests/ext_proc.yaml index a3c66bc9e..49145d24c 100644 --- a/pkg/manifests/ext_proc.yaml +++ b/pkg/manifests/ext_proc.yaml @@ -71,7 +71,7 @@ spec: spec: containers: - name: inference-gateway-ext-proc - image: us-central1-docker.pkg.dev/ahg-gke-dev/jobset2/epp:dfee85a + image: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/epp:main imagePullPolicy: Always args: - -poolName diff --git a/pkg/manifests/vllm/deployment.yaml b/pkg/manifests/vllm/deployment.yaml index a54d99b30..51689c9f2 100644 --- a/pkg/manifests/vllm/deployment.yaml +++ b/pkg/manifests/vllm/deployment.yaml @@ -26,7 +26,7 @@ spec: - "8000" - "--enable-lora" - "--max-loras" - - "2" + - "4" - "--max-cpu-loras" - "12" - "--lora-modules" diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 67377a69f..9dbcf7833 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -490,7 +490,8 @@ func BeforeSuit(t *testing.T) func() { } assert.EventuallyWithT(t, func(t *assert.CollectT) { - synced := serverRunner.Datastore.PoolHasSynced() && serverRunner.Datastore.ModelGet("my-model") != nil + _, modelExist := serverRunner.Datastore.ModelGet("my-model") + synced := serverRunner.Datastore.PoolHasSynced() && modelExist assert.True(t, synced, "Timeout waiting for the pool and models to sync") }, 10*time.Second, 10*time.Millisecond) From 225f6b1f8cf0b59f0b5e5facde1d4b62fd901777 Mon Sep 17 00:00:00 2001 From: ahg-g Date: Tue, 18 Feb 2025 19:48:02 +0000 Subject: [PATCH 7/8] Addressed more comments --- pkg/ext-proc/backend/datastore.go | 8 ++++---- pkg/ext-proc/backend/inferencepool_reconciler.go | 7 +++---- pkg/ext-proc/backend/provider.go | 2 +- pkg/ext-proc/test/utils.go | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index cc4bf6a16..6b8483d3d 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -31,7 +31,7 @@ type Datastore interface { // PodMetrics operations PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool - PodUpdateMetricsIfExist(pm *PodMetrics) bool + PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool PodGet(namespacedName types.NamespacedName) (*PodMetrics, bool) PodDelete(namespacedName types.NamespacedName) PodResyncAll(ctx context.Context, ctrlClient client.Client) @@ -115,10 +115,10 @@ func (ds *datastore) ModelDelete(modelName string) { } // /// Pods/endpoints APIs /// -func (ds *datastore) PodUpdateMetricsIfExist(pm *PodMetrics) bool { - if val, ok := ds.pods.Load(pm.NamespacedName); ok { +func (ds *datastore) PodUpdateMetricsIfExist(namespacedName types.NamespacedName, m *Metrics) bool { + if val, ok := ds.pods.Load(namespacedName); ok { existing := val.(*PodMetrics) - existing.Metrics = pm.Metrics + existing.Metrics = *m return true } return false diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 83f9ee9e0..167130d3a 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -8,7 +8,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" 
"k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" - klog "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" @@ -59,10 +58,10 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *v1alpha1.InferencePool) { logger := log.FromContext(ctx) - oldPool, _ := c.Datastore.PoolGet() + oldPool, err := c.Datastore.PoolGet() c.Datastore.PoolSet(newPool) - if oldPool == nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { - logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "target", klog.KMetadata(&newPool.ObjectMeta)) + if err != nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { + logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", newPool.Spec.Selector) c.Datastore.PodResyncAll(ctx, c.Client) } } diff --git a/pkg/ext-proc/backend/provider.go b/pkg/ext-proc/backend/provider.go index 7e55947da..bb575d191 100644 --- a/pkg/ext-proc/backend/provider.go +++ b/pkg/ext-proc/backend/provider.go @@ -109,7 +109,7 @@ func (p *Provider) refreshMetricsOnce(logger logr.Logger) error { errCh <- fmt.Errorf("failed to parse metrics from %s: %v", existing.NamespacedName, err) return } - p.datastore.PodUpdateMetricsIfExist(updated) + p.datastore.PodUpdateMetricsIfExist(updated.NamespacedName, &updated.Metrics) loggerTrace.Info("Updated metrics for pod", "pod", updated.NamespacedName, "metrics", updated.Metrics) }() return true diff --git a/pkg/ext-proc/test/utils.go b/pkg/ext-proc/test/utils.go index ff7cd89a7..46affae91 100644 --- a/pkg/ext-proc/test/utils.go +++ b/pkg/ext-proc/test/utils.go @@ -44,7 +44,7 @@ func StartExtProc( IP(pm.Address). Obj() datastore.PodUpdateOrAddIfNotExist(&pod) - datastore.PodUpdateMetricsIfExist(pm) + datastore.PodUpdateMetricsIfExist(pm.NamespacedName, &pm.Metrics) } pp := backend.NewProvider(pmc, datastore) if err := pp.Init(ctx, refreshMetricsInterval, refreshPrometheusMetricsInterval); err != nil { From b30c90d49d9d8f8e30d0f3ee6f7d603a797857bc Mon Sep 17 00:00:00 2001 From: ahg-g Date: Tue, 18 Feb 2025 21:17:36 +0000 Subject: [PATCH 8/8] Adding a comment --- pkg/ext-proc/backend/inferencepool_reconciler.go | 6 ++++++ test/integration/hermetic_test.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 167130d3a..6f52862e7 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -62,6 +62,12 @@ func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool * c.Datastore.PoolSet(newPool) if err != nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", newPool.Spec.Selector) + // A full resync is required to address two cases: + // 1) At startup, the pod events may get processed before the pool is synced with the datastore, + // and hence they will not be added to the store since pool selector is not known yet + // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need + // to resync the whole pool: remove pods in the store that don't match the new selector and add + // the ones that may have existed already to the store. 
c.Datastore.PodResyncAll(ctx, c.Client) } } diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 9dbcf7833..0e30ac696 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -380,7 +380,7 @@ func setUpHermeticServer(podMetrics []*backend.PodMetrics) (client extProcPb.Ext IP(pm.Address). Obj() serverRunner.Datastore.PodUpdateOrAddIfNotExist(&pod) - serverRunner.Datastore.PodUpdateMetricsIfExist(pm) + serverRunner.Datastore.PodUpdateMetricsIfExist(pm.NamespacedName, &pm.Metrics) } serverRunner.Provider = backend.NewProvider(pmc, serverRunner.Datastore) if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil {
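
For context, a minimal self-contained sketch (not part of the patch series; the type and function names below are simplified stand-ins, not the real backend package) of the comma-ok update pattern that PodUpdateMetricsIfExist and ModelGet adopt in the later patches: metrics are written only for endpoints the datastore already tracks, and the returned bool reports whether anything changed. It assumes k8s.io/apimachinery is available in go.mod.

// Illustrative sketch only; names are simplified assumptions, not the real
// pkg/ext-proc/backend package.
package main

import (
	"fmt"
	"sync"

	"k8s.io/apimachinery/pkg/types"
)

type metrics struct {
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

type podMetrics struct {
	NamespacedName types.NamespacedName
	Metrics        metrics
}

type store struct {
	pods sync.Map // key: types.NamespacedName, value: *podMetrics
}

// podUpdateMetricsIfExist mirrors the patch's semantics: metrics are only
// written for pods already tracked by the store, and the bool reports
// whether an update happened.
func (s *store) podUpdateMetricsIfExist(nn types.NamespacedName, m *metrics) bool {
	if val, ok := s.pods.Load(nn); ok {
		val.(*podMetrics).Metrics = *m
		return true
	}
	return false
}

func main() {
	s := &store{}
	nn := types.NamespacedName{Namespace: "default", Name: "pod1"}
	s.pods.Store(nn, &podMetrics{NamespacedName: nn})

	// pod1 is tracked, so the update succeeds.
	fmt.Println(s.podUpdateMetricsIfExist(nn, &metrics{WaitingQueueSize: 3, KVCacheUsagePercent: 0.2}))

	// pod2 is unknown to the store, so nothing is written and false is returned.
	missing := types.NamespacedName{Namespace: "default", Name: "pod2"}
	fmt.Println(s.podUpdateMetricsIfExist(missing, &metrics{}))
}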