diff --git a/.golangci.yml b/.golangci.yml
index 1462bcc77..2ad3b93da 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -14,7 +14,6 @@ linters:
     - dupword
     - durationcheck
     - fatcontext
-    - gci
     - ginkgolinter
     - gocritic
    - govet
diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go
index b466a2ed5..627ddbe52 100644
--- a/pkg/ext-proc/backend/datastore.go
+++ b/pkg/ext-proc/backend/datastore.go
@@ -1,13 +1,26 @@
 package backend
 
 import (
+	"context"
 	"errors"
 	"math/rand"
 	"sync"
+	"time"
 
+	"github.com/google/go-cmp/cmp"
 	"inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1"
 	logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging"
 	corev1 "k8s.io/api/core/v1"
+	v1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/api/meta"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/labels"
+	"k8s.io/client-go/informers"
+	informersv1 "k8s.io/client-go/informers/core/v1"
+	"k8s.io/client-go/kubernetes"
+	clientset "k8s.io/client-go/kubernetes"
+	listersv1 "k8s.io/client-go/listers/core/v1"
+	"k8s.io/client-go/tools/cache"
 	"k8s.io/klog/v2"
 )
 
@@ -15,8 +28,9 @@ func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore {
 	store := &K8sDatastore{
 		poolMu:          sync.RWMutex{},
 		InferenceModels: &sync.Map{},
-		pods:            &sync.Map{},
 	}
+
+	store.podListerFactory = store.createPodLister
 	for _, opt := range options {
 		opt(store)
 	}
@@ -25,29 +39,68 @@ func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore {
 
 // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api)
 type K8sDatastore struct {
+	client kubernetes.Interface
 	// poolMu is used to synchronize access to the inferencePool.
-	poolMu          sync.RWMutex
-	inferencePool   *v1alpha1.InferencePool
-	InferenceModels *sync.Map
-	pods            *sync.Map
+	poolMu           sync.RWMutex
+	inferencePool    *v1alpha1.InferencePool
+	podListerFactory PodListerFactory
+	podLister        *PodLister
+	InferenceModels  *sync.Map
 }
 
 type K8sDatastoreOption func(*K8sDatastore)
+type PodListerFactory func(*v1alpha1.InferencePool) *PodLister
 
-// WithPods can be used in tests to override the pods.
-func WithPods(pods []*PodMetrics) K8sDatastoreOption {
+// WithPodListerFactory can be used in tests to override how the pod lister is built.
+func WithPodListerFactory(factory PodListerFactory) K8sDatastoreOption {
 	return func(store *K8sDatastore) {
-		store.pods = &sync.Map{}
-		for _, pod := range pods {
-			store.pods.Store(pod.Pod, true)
-		}
+		store.podListerFactory = factory
 	}
 }
 
+type PodLister struct {
+	Lister         listersv1.PodLister
+	sharedInformer informers.SharedInformerFactory
+}
+
+func (l *PodLister) listEverything() ([]*corev1.Pod, error) {
+	return l.Lister.List(labels.Everything())
+}
+
+func (ds *K8sDatastore) SetClient(client kubernetes.Interface) {
+	ds.client = client
+}
+
 func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) {
 	ds.poolMu.Lock()
 	defer ds.poolMu.Unlock()
+
+	if ds.inferencePool != nil && cmp.Equal(ds.inferencePool.Spec.Selector, pool.Spec.Selector) {
+		// Pool updated, but the selector stayed the same, so no need to change the informer.
+		ds.inferencePool = pool
+		return
+	}
+
+	// New pool or selector updated.
 	ds.inferencePool = pool
+
+	if ds.podLister != nil && ds.podLister.sharedInformer != nil {
+		// Shut down the old informer asynchronously since this takes a few seconds.
+		go func() {
+			ds.podLister.sharedInformer.Shutdown()
+		}()
+	}
+
+	if ds.podListerFactory != nil {
+		// Create a new informer with the new selector.
+		ds.podLister = ds.podListerFactory(ds.inferencePool)
+		if ds.podLister != nil && ds.podLister.sharedInformer != nil {
+			ctx := context.Background()
+			ds.podLister.sharedInformer.Start(ctx.Done())
+			ds.podLister.sharedInformer.WaitForCacheSync(ctx.Done())
+		}
+	}
 }
 
 func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) {
@@ -59,13 +112,58 @@ func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) {
 	return ds.inferencePool, nil
 }
 
-func (ds *K8sDatastore) GetPodIPs() []string {
-	var ips []string
-	ds.pods.Range(func(name, pod any) bool {
-		ips = append(ips, pod.(*corev1.Pod).Status.PodIP)
-		return true
-	})
-	return ips
+func (ds *K8sDatastore) createPodLister(pool *v1alpha1.InferencePool) *PodLister {
+	if ds.client == nil {
+		return nil
+	}
+	klog.V(logutil.DEFAULT).Infof("Creating informer for pool %v", pool.Name)
+	selectorSet := make(map[string]string)
+	for k, v := range pool.Spec.Selector {
+		selectorSet[string(k)] = string(v)
+	}
+
+	newPodInformer := func(cs clientset.Interface, resyncPeriod time.Duration) cache.SharedIndexInformer {
+		informer := informersv1.NewFilteredPodInformer(cs, pool.Namespace, resyncPeriod, cache.Indexers{}, func(options *metav1.ListOptions) {
+			options.LabelSelector = labels.SelectorFromSet(selectorSet).String()
+		})
+		err := informer.SetTransform(func(obj interface{}) (interface{}, error) {
+			// Remove unnecessary fields to reduce the memory footprint.
+			if accessor, err := meta.Accessor(obj); err == nil {
+				if accessor.GetManagedFields() != nil {
+					accessor.SetManagedFields(nil)
+				}
+			}
+			return obj, nil
+		})
+		if err != nil {
+			klog.Errorf("Failed to set pod transformer: %v", err)
+		}
+		return informer
+	}
+	// A resync period of 0 disables resyncing. Resyncing every hour (the controller-runtime default) is not very useful:
+	// if the watch goes wrong, nobody will wait an hour for things to get fixed.
+	// As a precedent, kube-scheduler also disables resync, since regularly listing all pods from the api-server is expensive.
+ resyncPeriod := time.Duration(0) + sharedInformer := informers.NewSharedInformerFactory(ds.client, resyncPeriod) + sharedInformer.InformerFor(&v1.Pod{}, newPodInformer) + + return &PodLister{ + Lister: sharedInformer.Core().V1().Pods().Lister(), + sharedInformer: sharedInformer, + } +} + +func (ds *K8sDatastore) getPods() ([]*corev1.Pod, error) { + ds.poolMu.RLock() + defer ds.poolMu.RUnlock() + if !ds.HasSynced() { + return nil, errors.New("InferencePool is not initialized in datastore") + } + pods, err := ds.podLister.listEverything() + if err != nil { + return nil, err + } + return pods, nil } func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) { diff --git a/pkg/ext-proc/backend/endpointslice_reconciler.go b/pkg/ext-proc/backend/endpointslice_reconciler.go deleted file mode 100644 index a2a9790f2..000000000 --- a/pkg/ext-proc/backend/endpointslice_reconciler.go +++ /dev/null @@ -1,109 +0,0 @@ -package backend - -import ( - "context" - "strconv" - "time" - - "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" - logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" - discoveryv1 "k8s.io/api/discovery/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/client-go/tools/record" - klog "k8s.io/klog/v2" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/builder" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/predicate" -) - -var ( - serviceOwnerLabel = "kubernetes.io/service-name" -) - -type EndpointSliceReconciler struct { - client.Client - Scheme *runtime.Scheme - Record record.EventRecorder - ServiceName string - Zone string - Datastore *K8sDatastore -} - -func (c *EndpointSliceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - inferencePool, err := c.Datastore.getInferencePool() - if err != nil { - klog.V(logutil.DEFAULT).Infof("Skipping reconciling EndpointSlice because the InferencePool is not available yet: %v", err) - return ctrl.Result{Requeue: true, RequeueAfter: time.Second}, nil - } - - klog.V(logutil.DEFAULT).Info("Reconciling EndpointSlice ", req.NamespacedName) - - endpointSlice := &discoveryv1.EndpointSlice{} - if err := c.Get(ctx, req.NamespacedName, endpointSlice); err != nil { - klog.Errorf("Unable to get EndpointSlice: %v", err) - return ctrl.Result{}, err - } - c.updateDatastore(endpointSlice, inferencePool) - - return ctrl.Result{}, nil -} - -// TODO: Support multiple endpointslices for a single service -func (c *EndpointSliceReconciler) updateDatastore( - slice *discoveryv1.EndpointSlice, - inferencePool *v1alpha1.InferencePool) { - podMap := make(map[Pod]bool) - - for _, endpoint := range slice.Endpoints { - klog.V(logutil.DEFAULT).Infof("Zone: %v \n endpoint: %+v \n", c.Zone, endpoint) - if c.validPod(endpoint) { - pod := Pod{ - Name: endpoint.TargetRef.Name, - Address: endpoint.Addresses[0] + ":" + strconv.Itoa(int(inferencePool.Spec.TargetPortNumber)), - } - podMap[pod] = true - klog.V(logutil.DEFAULT).Infof("Storing pod %v", pod) - c.Datastore.pods.Store(pod, true) - } - } - - removeOldPods := func(k, v any) bool { - pod, ok := k.(Pod) - if !ok { - klog.Errorf("Unable to cast key to Pod: %v", k) - return false - } - if _, ok := podMap[pod]; !ok { - klog.V(logutil.DEFAULT).Infof("Removing pod %v", pod) - c.Datastore.pods.Delete(pod) - } - return true - } - c.Datastore.pods.Range(removeOldPods) -} - -func (c *EndpointSliceReconciler) SetupWithManager(mgr 
ctrl.Manager) error { - ownsEndPointSlice := func(object client.Object) bool { - // Check if the object is an EndpointSlice - endpointSlice, ok := object.(*discoveryv1.EndpointSlice) - if !ok { - return false - } - - gotLabel := endpointSlice.ObjectMeta.Labels[serviceOwnerLabel] - wantLabel := c.ServiceName - return gotLabel == wantLabel - } - - return ctrl.NewControllerManagedBy(mgr). - For(&discoveryv1.EndpointSlice{}, - builder.WithPredicates(predicate.NewPredicateFuncs(ownsEndPointSlice))). - Complete(c) -} - -func (c *EndpointSliceReconciler) validPod(endpoint discoveryv1.Endpoint) bool { - validZone := c.Zone == "" || c.Zone != "" && *endpoint.Zone == c.Zone - return validZone && *endpoint.Conditions.Ready - -} diff --git a/pkg/ext-proc/backend/endpointslice_reconcilier_test.go b/pkg/ext-proc/backend/endpointslice_reconcilier_test.go deleted file mode 100644 index e3c927ba8..000000000 --- a/pkg/ext-proc/backend/endpointslice_reconcilier_test.go +++ /dev/null @@ -1,202 +0,0 @@ -package backend - -import ( - "sync" - "testing" - - "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" - v1 "k8s.io/api/core/v1" - discoveryv1 "k8s.io/api/discovery/v1" -) - -var ( - basePod1 = Pod{Name: "pod1"} - basePod2 = Pod{Name: "pod2"} - basePod3 = Pod{Name: "pod3"} -) - -func TestUpdateDatastore_EndpointSliceReconciler(t *testing.T) { - tests := []struct { - name string - datastore *K8sDatastore - incomingSlice *discoveryv1.EndpointSlice - wantPods *sync.Map - }{ - { - name: "Add new pod", - datastore: &K8sDatastore{ - pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - TargetPortNumber: int32(8000), - }, - }, - }, - incomingSlice: &discoveryv1.EndpointSlice{ - Endpoints: []discoveryv1.Endpoint{ - { - TargetRef: &v1.ObjectReference{ - Name: "pod1", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod2", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod3", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - }, - }, - wantPods: populateMap(basePod1, basePod2, basePod3), - }, - { - name: "New pod, but its not ready yet. 
Do not add.", - datastore: &K8sDatastore{ - pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - TargetPortNumber: int32(8000), - }, - }, - }, - incomingSlice: &discoveryv1.EndpointSlice{ - Endpoints: []discoveryv1.Endpoint{ - { - TargetRef: &v1.ObjectReference{ - Name: "pod1", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod2", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod3", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: new(bool), - }, - Addresses: []string{"0.0.0.0"}, - }, - }, - }, - wantPods: populateMap(basePod1, basePod2), - }, - { - name: "Existing pod not ready, new pod added, and is ready", - datastore: &K8sDatastore{ - pods: populateMap(basePod1, basePod2), - inferencePool: &v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{ - TargetPortNumber: int32(8000), - }, - }, - }, - incomingSlice: &discoveryv1.EndpointSlice{ - Endpoints: []discoveryv1.Endpoint{ - { - TargetRef: &v1.ObjectReference{ - Name: "pod1", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: new(bool), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod2", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - { - TargetRef: &v1.ObjectReference{ - Name: "pod3", - }, - Zone: new(string), - Conditions: discoveryv1.EndpointConditions{ - Ready: truePointer(), - }, - Addresses: []string{"0.0.0.0"}, - }, - }, - }, - wantPods: populateMap(basePod3, basePod2), - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - endpointSliceReconciler := &EndpointSliceReconciler{Datastore: test.datastore, Zone: ""} - endpointSliceReconciler.updateDatastore(test.incomingSlice, test.datastore.inferencePool) - - if mapsEqual(endpointSliceReconciler.Datastore.pods, test.wantPods) { - t.Errorf("Unexpected output pod mismatch. 
\n Got %v \n Want: %v \n", - endpointSliceReconciler.Datastore.pods, - test.wantPods) - } - }) - } -} - -func mapsEqual(map1, map2 *sync.Map) bool { - equal := true - - map1.Range(func(k, v any) bool { - if _, ok := map2.Load(k); !ok { - equal = false - return false - } - return true - }) - map2.Range(func(k, v any) bool { - if _, ok := map1.Load(k); !ok { - equal = false - return false - } - return true - }) - - return equal -} - -func truePointer() *bool { - primitivePointersAreSilly := true - return &primitivePointersAreSilly -} diff --git a/pkg/ext-proc/backend/fake.go b/pkg/ext-proc/backend/fake.go index c45454975..63f20db60 100644 --- a/pkg/ext-proc/backend/fake.go +++ b/pkg/ext-proc/backend/fake.go @@ -8,16 +8,16 @@ import ( ) type FakePodMetricsClient struct { - Err map[Pod]error - Res map[Pod]*PodMetrics + Err map[string]error + Res map[string]*PodMetrics } func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) { - if err, ok := f.Err[pod]; ok { + if err, ok := f.Err[pod.Name]; ok { return nil, err } - klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod]) - return f.Res[pod], nil + klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod.Name]) + return f.Res[pod.Name], nil } type FakeDataStore struct { diff --git a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go index 5609ca532..117766b9c 100644 --- a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go +++ b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go @@ -146,3 +146,24 @@ func populateServiceMap(services ...*v1alpha1.InferenceModel) *sync.Map { } return returnVal } + +func mapsEqual(map1, map2 *sync.Map) bool { + equal := true + + map1.Range(func(k, v any) bool { + if _, ok := map2.Load(k); !ok { + equal = false + return false + } + return true + }) + map2.Range(func(k, v any) bool { + if _, ok := map1.Load(k); !ok { + equal = false + return false + } + return true + }) + + return equal +} diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 35a41f8ff..0c2ae75f5 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -21,7 +21,6 @@ type InferencePoolReconciler struct { Record record.EventRecorder PoolNamespacedName types.NamespacedName Datastore *K8sDatastore - Zone string } func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { diff --git a/pkg/ext-proc/backend/provider.go b/pkg/ext-proc/backend/provider.go index 8bf672579..d6ccf85fa 100644 --- a/pkg/ext-proc/backend/provider.go +++ b/pkg/ext-proc/backend/provider.go @@ -3,11 +3,14 @@ package backend import ( "context" "fmt" + "math/rand" + "strconv" "sync" "time" "go.uber.org/multierr" logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" + corev1 "k8s.io/api/core/v1" klog "k8s.io/klog/v2" ) @@ -26,7 +29,8 @@ func NewProvider(pmc PodMetricsClient, datastore *K8sDatastore) *Provider { // Provider provides backend pods and information such as metrics. 
 type Provider struct {
-	// key: Pod, value: *PodMetrics
+	// key: PodName, value: *PodMetrics
+	// TODO: change to use NamespacedName once we support multi-tenant inferencePools
 	podMetrics sync.Map
 	pmc        PodMetricsClient
 	datastore  *K8sDatastore
 }
 
@@ -47,11 +51,11 @@ func (p *Provider) AllPodMetrics() []*PodMetrics {
 }
 
 func (p *Provider) UpdatePodMetrics(pod Pod, pm *PodMetrics) {
-	p.podMetrics.Store(pod, pm)
+	p.podMetrics.Store(pod.Name, pm)
 }
 
 func (p *Provider) GetPodMetrics(pod Pod) (*PodMetrics, bool) {
-	val, ok := p.podMetrics.Load(pod)
+	val, ok := p.podMetrics.Load(pod.Name)
 	if ok {
 		return val.(*PodMetrics), true
 	}
@@ -101,31 +105,70 @@ func (p *Provider) Init(refreshPodsInterval, refreshMetricsInterval time.Duratio
 // refreshPodsOnce lists pods and updates keys in the podMetrics map.
 // Note this function doesn't update the PodMetrics value, it's done separately.
 func (p *Provider) refreshPodsOnce() {
-	// merge new pods with cached ones.
-	// add new pod to the map
-	addNewPods := func(k, v any) bool {
-		pod := k.(Pod)
-		if _, ok := p.podMetrics.Load(pod); !ok {
-			new := &PodMetrics{
-				Pod: pod,
-				Metrics: Metrics{
-					ActiveModels: make(map[string]int),
-				},
-			}
-			p.podMetrics.Store(pod, new)
+	pods, err := p.datastore.getPods()
+	if err != nil {
+		klog.V(logutil.DEFAULT).Infof("Couldn't list pods: %v", err)
+		p.podMetrics.Clear()
+		return
+	}
+	pool, _ := p.datastore.getInferencePool()
+	// revision marks the entries seen in this refresh so that the cleanup pass below can remove
+	// metrics for pods that no longer exist, without building a map of the listed pods.
+	// Any random id works as long as it differs from the previous refresh, so collisions are about
+	// as unlikely as drawing the same number twice at random from the range 0 - maxInt.
+	revision := rand.Int()
+	ready := 0
+	for _, pod := range pods {
+		if !podIsReady(pod) {
+			continue
 		}
-		return true
+		// a ready pod
+		ready++
+		if val, ok := p.podMetrics.Load(pod.Name); ok {
+			// pod already exists
+			pm := val.(*PodMetrics)
+			pm.revision = revision
+			continue
+		}
+		// new pod, add to the store for probing
+		new := &PodMetrics{
+			Pod: Pod{
+				Name:    pod.Name,
+				Address: pod.Status.PodIP + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)),
+			},
+			Metrics: Metrics{
+				ActiveModels: make(map[string]int),
+			},
+			revision: revision,
+		}
+		p.podMetrics.Store(pod.Name, new)
 	}
+
+	klog.V(logutil.DEFAULT).Infof("Pods in pool %s/%s with selector %v: total=%v ready=%v",
+		pool.Namespace, pool.Name, pool.Spec.Selector, len(pods), ready)
+	// remove pods that don't exist anymore.
mergeFn := func(k, v any) bool { - pod := k.(Pod) - if _, ok := p.datastore.pods.Load(pod); !ok { - p.podMetrics.Delete(pod) + pm := v.(*PodMetrics) + if pm.revision != revision { + p.podMetrics.Delete(pm.Pod.Name) } return true } p.podMetrics.Range(mergeFn) - p.datastore.pods.Range(addNewPods) +} + +func podIsReady(pod *corev1.Pod) bool { + if pod.DeletionTimestamp != nil { + return false + } + for _, condition := range pod.Status.Conditions { + if condition.Type == corev1.PodReady { + return condition.Status == corev1.ConditionTrue + } + } + return false } func (p *Provider) refreshMetricsOnce() error { @@ -141,8 +184,8 @@ func (p *Provider) refreshMetricsOnce() error { errCh := make(chan error) processOnePod := func(key, value any) bool { klog.V(logutil.TRACE).Infof("Processing pod %v and metric %v", key, value) - pod := key.(Pod) existing := value.(*PodMetrics) + pod := existing.Pod wg.Add(1) go func() { defer wg.Done() diff --git a/pkg/ext-proc/backend/provider_test.go b/pkg/ext-proc/backend/provider_test.go index ad231f575..9159ba481 100644 --- a/pkg/ext-proc/backend/provider_test.go +++ b/pkg/ext-proc/backend/provider_test.go @@ -2,17 +2,18 @@ package backend import ( "errors" - "sync" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" + testingutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" + corev1 "k8s.io/api/core/v1" ) var ( pod1 = &PodMetrics{ - Pod: Pod{Name: "pod1"}, + Pod: Pod{Name: "pod1", Address: "address1:9009"}, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -24,7 +25,7 @@ var ( }, } pod2 = &PodMetrics{ - Pod: Pod{Name: "pod2"}, + Pod: Pod{Name: "pod2", Address: "address2:9009"}, Metrics: Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.2, @@ -38,44 +39,67 @@ var ( ) func TestProvider(t *testing.T) { + allPodsLister := &testingutil.FakePodLister{ + PodsList: []*corev1.Pod{ + testingutil.MakePod(pod1.Pod.Name).SetReady().SetPodIP("address1").Obj(), + testingutil.MakePod(pod2.Pod.Name).SetReady().SetPodIP("address2").Obj(), + }, + } + allPodsMetricsClient := &FakePodMetricsClient{ + Res: map[string]*PodMetrics{ + pod1.Pod.Name: pod1, + pod2.Pod.Name: pod2, + }, + } + tests := []struct { - name string - pmc PodMetricsClient - datastore *K8sDatastore - initErr bool - want []*PodMetrics + name string + initPodMetrics []*PodMetrics + lister *testingutil.FakePodLister + pmc PodMetricsClient + step func(*Provider) + want []*PodMetrics }{ { - name: "Init success", - datastore: &K8sDatastore{ - pods: populateMap(pod1.Pod, pod2.Pod), + name: "Init without refreshing pods", + initPodMetrics: []*PodMetrics{pod1, pod2}, + lister: allPodsLister, + pmc: allPodsMetricsClient, + step: func(p *Provider) { + _ = p.refreshMetricsOnce() }, - pmc: &FakePodMetricsClient{ - Res: map[Pod]*PodMetrics{ - pod1.Pod: pod1, - pod2.Pod: pod2, - }, + want: []*PodMetrics{pod1, pod2}, + }, + { + name: "Fetching all success", + lister: allPodsLister, + pmc: allPodsMetricsClient, + step: func(p *Provider) { + p.refreshPodsOnce() + _ = p.refreshMetricsOnce() }, want: []*PodMetrics{pod1, pod2}, }, { - name: "Fetch metrics error", + name: "Fetch metrics error", + lister: allPodsLister, pmc: &FakePodMetricsClient{ - Err: map[Pod]error{ - pod2.Pod: errors.New("injected error"), + Err: map[string]error{ + pod2.Pod.Name: errors.New("injected error"), }, - Res: map[Pod]*PodMetrics{ - pod1.Pod: pod1, + Res: 
map[string]*PodMetrics{ + pod1.Pod.Name: pod1, }, }, - datastore: &K8sDatastore{ - pods: populateMap(pod1.Pod, pod2.Pod), + step: func(p *Provider) { + p.refreshPodsOnce() + _ = p.refreshMetricsOnce() }, want: []*PodMetrics{ pod1, // Failed to fetch pod2 metrics so it remains the default values. { - Pod: Pod{Name: "pod2"}, + Pod: pod2.Pod, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -85,30 +109,73 @@ func TestProvider(t *testing.T) { }, }, }, + { + name: "A new pod added", + initPodMetrics: []*PodMetrics{pod2}, + lister: allPodsLister, + pmc: allPodsMetricsClient, + step: func(p *Provider) { + p.refreshPodsOnce() + _ = p.refreshMetricsOnce() + }, + want: []*PodMetrics{pod1, pod2}, + }, + { + name: "A pod removed", + initPodMetrics: []*PodMetrics{pod1, pod2}, + lister: &testingutil.FakePodLister{ + PodsList: []*corev1.Pod{ + testingutil.MakePod(pod2.Pod.Name).SetReady().SetPodIP("address2").Obj(), + }, + }, + pmc: allPodsMetricsClient, + step: func(p *Provider) { + p.refreshPodsOnce() + _ = p.refreshMetricsOnce() + }, + want: []*PodMetrics{pod2}, + }, + { + name: "A pod removed, another added", + initPodMetrics: []*PodMetrics{pod1}, + lister: &testingutil.FakePodLister{ + PodsList: []*corev1.Pod{ + testingutil.MakePod(pod1.Pod.Name).SetReady().SetPodIP("address1").Obj(), + }, + }, + pmc: allPodsMetricsClient, + step: func(p *Provider) { + p.refreshPodsOnce() + _ = p.refreshMetricsOnce() + }, + want: []*PodMetrics{pod1}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - p := NewProvider(test.pmc, test.datastore) - err := p.Init(time.Millisecond, time.Millisecond) - if test.initErr != (err != nil) { - t.Fatalf("Unexpected error, got: %v, want: %v", err, test.initErr) + datastore := NewK8sDataStore(WithPodListerFactory( + func(pool *v1alpha1.InferencePool) *PodLister { + return &PodLister{ + Lister: test.lister, + } + })) + datastore.setInferencePool(&v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{TargetPortNumber: 9009}, + }) + p := NewProvider(test.pmc, datastore) + for _, m := range test.initPodMetrics { + p.UpdatePodMetrics(m.Pod, m) } + test.step(p) metrics := p.AllPodMetrics() lessFunc := func(a, b *PodMetrics) bool { return a.String() < b.String() } - if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc)); diff != "" { + if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc), + cmpopts.IgnoreFields(PodMetrics{}, "revision")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) } } - -func populateMap(pods ...Pod) *sync.Map { - newMap := &sync.Map{} - for _, pod := range pods { - newMap.Store(pod, true) - } - return newMap -} diff --git a/pkg/ext-proc/backend/types.go b/pkg/ext-proc/backend/types.go index 7e399fedc..d375e4ec2 100644 --- a/pkg/ext-proc/backend/types.go +++ b/pkg/ext-proc/backend/types.go @@ -28,6 +28,7 @@ type Metrics struct { type PodMetrics struct { Pod Metrics + revision int } func (pm *PodMetrics) String() string { diff --git a/pkg/ext-proc/health.go b/pkg/ext-proc/health.go index 488851eb3..62527d06a 100644 --- a/pkg/ext-proc/health.go +++ b/pkg/ext-proc/health.go @@ -7,6 +7,7 @@ import ( healthPb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" + logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" klog "k8s.io/klog/v2" ) @@ -19,7 +20,7 @@ func (s *healthServer) Check(ctx context.Context, in 
*healthPb.HealthCheckReques klog.Infof("gRPC health check not serving: %s", in.String()) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_NOT_SERVING}, nil } - klog.Infof("gRPC health check serving: %s", in.String()) + klog.V(logutil.DEBUG).Infof("gRPC health check serving: %s", in.String()) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil } diff --git a/pkg/ext-proc/main.go b/pkg/ext-proc/main.go index a783aa2c5..98b7e6cad 100644 --- a/pkg/ext-proc/main.go +++ b/pkg/ext-proc/main.go @@ -18,6 +18,7 @@ import ( runserver "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/server" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/kubernetes" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" "k8s.io/component-base/metrics/legacyregistry" @@ -53,14 +54,6 @@ var ( "poolNamespace", runserver.DefaultPoolNamespace, "Namespace of the InferencePool this Endpoint Picker is associated with.") - serviceName = flag.String( - "serviceName", - runserver.DefaultServiceName, - "Name of the Service that will be used to read EndpointSlices from") - zone = flag.String( - "zone", - runserver.DefaultZone, - "The zone that this instance is created in. Will be passed to the corresponding endpointSlice. ") refreshPodsInterval = flag.Duration( "refreshPodsInterval", runserver.DefaultRefreshPodsInterval, @@ -106,8 +99,6 @@ func main() { TargetEndpointKey: *targetEndpointKey, PoolName: *poolName, PoolNamespace: *poolNamespace, - ServiceName: *serviceName, - Zone: *zone, RefreshPodsInterval: *refreshPodsInterval, RefreshMetricsInterval: *refreshMetricsInterval, Scheme: scheme, @@ -116,12 +107,15 @@ func main() { } serverRunner.Setup() + k8sClient, err := kubernetes.NewForConfigAndClient(cfg, serverRunner.Manager.GetHTTPClient()) + if err != nil { + klog.Fatalf("Failed to create client: %v", err) + } + datastore.SetClient(k8sClient) + // Start health and ext-proc servers in goroutines healthSvr := startHealthServer(datastore, *grpcHealthPort) - extProcSvr := serverRunner.Start( - datastore, - &vllm.PodMetricsClientImpl{}, - ) + extProcSvr := serverRunner.Start(&vllm.PodMetricsClientImpl{}) // Start metrics handler metricsSvr := startMetricsHandler(*metricsPort, cfg) @@ -216,9 +210,5 @@ func validateFlags() error { return fmt.Errorf("required %q flag not set", "poolName") } - if *serviceName == "" { - return fmt.Errorf("required %q flag not set", "serviceName") - } - return nil } diff --git a/pkg/ext-proc/scheduling/filter_test.go b/pkg/ext-proc/scheduling/filter_test.go index d88f437c7..34731d152 100644 --- a/pkg/ext-proc/scheduling/filter_test.go +++ b/pkg/ext-proc/scheduling/filter_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" ) @@ -206,7 +207,7 @@ func TestFilter(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + if diff := cmp.Diff(test.output, got, cmpopts.IgnoreFields(backend.PodMetrics{}, "revision")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -400,7 +401,7 @@ func TestFilterFunc(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got); diff != "" { + if diff := cmp.Diff(test.output, got, 
cmpopts.IgnoreFields(backend.PodMetrics{}, "revision")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) diff --git a/pkg/ext-proc/server/runserver.go b/pkg/ext-proc/server/runserver.go index 1c9c1b2e2..981dab114 100644 --- a/pkg/ext-proc/server/runserver.go +++ b/pkg/ext-proc/server/runserver.go @@ -23,14 +23,12 @@ type ExtProcServerRunner struct { TargetEndpointKey string PoolName string PoolNamespace string - ServiceName string - Zone string RefreshPodsInterval time.Duration RefreshMetricsInterval time.Duration Scheme *runtime.Scheme Config *rest.Config Datastore *backend.K8sDatastore - manager ctrl.Manager + Manager ctrl.Manager } // Default values for CLI flags in main @@ -39,8 +37,6 @@ const ( DefaultTargetEndpointKey = "x-gateway-destination-endpoint" // default for --targetEndpointKey DefaultPoolName = "" // required but no default DefaultPoolNamespace = "default" // default for --poolNamespace - DefaultServiceName = "" // required but no default - DefaultZone = "" // default for --zone DefaultRefreshPodsInterval = 10 * time.Second // default for --refreshPodsInterval DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refreshMetricsInterval ) @@ -51,22 +47,20 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { TargetEndpointKey: DefaultTargetEndpointKey, PoolName: DefaultPoolName, PoolNamespace: DefaultPoolNamespace, - ServiceName: DefaultServiceName, - Zone: DefaultZone, RefreshPodsInterval: DefaultRefreshPodsInterval, RefreshMetricsInterval: DefaultRefreshMetricsInterval, // Scheme, Config, and Datastore can be assigned later. } } -// Setup creates the reconcilers for pools, models, and endpointSlices and starts the manager. +// Setup creates the reconcilers for pools and models and starts the manager. func (r *ExtProcServerRunner) Setup() { // Create a new manager to manage controllers mgr, err := ctrl.NewManager(r.Config, ctrl.Options{Scheme: r.Scheme}) if err != nil { klog.Fatalf("Failed to create controller manager: %v", err) } - r.manager = mgr + r.Manager = mgr // Create the controllers and register them with the manager if err := (&backend.InferencePoolReconciler{ @@ -94,22 +88,10 @@ func (r *ExtProcServerRunner) Setup() { }).SetupWithManager(mgr); err != nil { klog.Fatalf("Failed setting up InferenceModelReconciler: %v", err) } - - if err := (&backend.EndpointSliceReconciler{ - Datastore: r.Datastore, - Scheme: mgr.GetScheme(), - Client: mgr.GetClient(), - Record: mgr.GetEventRecorderFor("endpointslice"), - ServiceName: r.ServiceName, - Zone: r.Zone, - }).SetupWithManager(mgr); err != nil { - klog.Fatalf("Failed setting up EndpointSliceReconciler: %v", err) - } } // Start starts the Envoy external processor server in a goroutine. 
func (r *ExtProcServerRunner) Start( - podDatastore *backend.K8sDatastore, podMetricsClient backend.PodMetricsClient, ) *grpc.Server { svr := grpc.NewServer() @@ -122,7 +104,7 @@ func (r *ExtProcServerRunner) Start( klog.Infof("Ext-proc server listening on port: %d", r.GrpcPort) // Initialize backend provider - pp := backend.NewProvider(podMetricsClient, podDatastore) + pp := backend.NewProvider(podMetricsClient, r.Datastore) if err := pp.Init(r.RefreshPodsInterval, r.RefreshMetricsInterval); err != nil { klog.Fatalf("Failed to initialize backend provider: %v", err) } @@ -143,13 +125,12 @@ func (r *ExtProcServerRunner) Start( } func (r *ExtProcServerRunner) StartManager() { - if r.manager == nil { + if r.Manager == nil { klog.Fatalf("Runner has no manager setup to run: %v", r) } // Start the controller manager. Blocking and will return when shutdown is complete. klog.Infof("Starting controller manager") - mgr := r.manager - if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { + if err := r.Manager.Start(ctrl.SetupSignalHandler()); err != nil { klog.Fatalf("Error starting controller manager: %v", err) } klog.Info("Controller manager shutting down") diff --git a/pkg/ext-proc/test/utils.go b/pkg/ext-proc/test/utils.go index 63972849e..a9dc4efa0 100644 --- a/pkg/ext-proc/test/utils.go +++ b/pkg/ext-proc/test/utils.go @@ -18,13 +18,13 @@ import ( func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics, models map[string]*v1alpha1.InferenceModel) *grpc.Server { ps := make(backend.PodSet) - pms := make(map[backend.Pod]*backend.PodMetrics) + pms := make(map[string]*backend.PodMetrics) for _, pod := range pods { ps[pod.Pod] = true - pms[pod.Pod] = pod + pms[pod.Pod.Name] = pod } pmc := &backend.FakePodMetricsClient{Res: pms} - pp := backend.NewProvider(pmc, backend.NewK8sDataStore(backend.WithPods(pods))) + pp := backend.NewProvider(pmc, backend.NewK8sDataStore()) if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil { klog.Fatalf("failed to initialize: %v", err) } diff --git a/pkg/ext-proc/util/testing/lister.go b/pkg/ext-proc/util/testing/lister.go new file mode 100644 index 000000000..023f30a1d --- /dev/null +++ b/pkg/ext-proc/util/testing/lister.go @@ -0,0 +1,19 @@ +package testing + +import ( + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/labels" + listersv1 "k8s.io/client-go/listers/core/v1" +) + +type FakePodLister struct { + PodsList []*v1.Pod +} + +func (l *FakePodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { + return l.PodsList, nil +} + +func (l *FakePodLister) Pods(namespace string) listersv1.PodNamespaceLister { + panic("not implemented") +} diff --git a/pkg/ext-proc/util/testing/wrappers.go b/pkg/ext-proc/util/testing/wrappers.go new file mode 100644 index 000000000..7b593bbd9 --- /dev/null +++ b/pkg/ext-proc/util/testing/wrappers.go @@ -0,0 +1,38 @@ +package testing + +import ( + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// PodWrapper wraps a Pod inside. +type PodWrapper struct{ corev1.Pod } + +// MakePod creates a Pod wrapper. +func MakePod(name string) *PodWrapper { + return &PodWrapper{ + corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + }, + } +} + +// Obj returns the inner Pod. 
+func (p *PodWrapper) Obj() *corev1.Pod { + return &p.Pod +} + +func (p *PodWrapper) SetReady() *PodWrapper { + p.Status.Conditions = []corev1.PodCondition{{ + Type: corev1.PodReady, + Status: corev1.ConditionTrue, + }} + return p +} + +func (p *PodWrapper) SetPodIP(podIP string) *PodWrapper { + p.Status.PodIP = podIP + return p +} diff --git a/pkg/manifests/ext_proc.yaml b/pkg/manifests/ext_proc.yaml index 410c31ed6..2ec93fd2b 100644 --- a/pkg/manifests/ext_proc.yaml +++ b/pkg/manifests/ext_proc.yaml @@ -77,8 +77,6 @@ spec: - "vllm-llama2-7b-pool" - -v - "3" - - -serviceName - - "vllm-llama2-7b-pool" - -grpcPort - "9002" - -grpcHealthPort diff --git a/pkg/manifests/vllm/deployment.yaml b/pkg/manifests/vllm/deployment.yaml index 4af0891d7..1f5073e98 100644 --- a/pkg/manifests/vllm/deployment.yaml +++ b/pkg/manifests/vllm/deployment.yaml @@ -1,16 +1,3 @@ -apiVersion: v1 -kind: Service -metadata: - name: vllm-llama2-7b-pool -spec: - selector: - app: vllm-llama2-7b-pool - ports: - - protocol: TCP - port: 8000 - targetPort: 8000 - type: ClusterIP ---- apiVersion: apps/v1 kind: Deployment metadata: diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index c2c1ea928..019e858a2 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -245,11 +245,6 @@ func createModelServer(k8sClient client.Client, secretPath, deployPath string) { // Wait for the deployment to be available. testutils.DeploymentAvailable(ctx, k8sClient, deploy, modelReadyTimeout, interval) - - // Wait for the service to exist. - testutils.EventuallyExists(ctx, func() error { - return k8sClient.Get(ctx, types.NamespacedName{Namespace: nsName, Name: modelServerName}, &corev1.Service{}) - }, existsTimeout, interval) } // createEnvoy creates the envoy proxy resources used for testing from the given filePath. 
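Side note on the new test helpers added in pkg/ext-proc/util/testing: inside the backend package tests, the FakePodLister, the MakePod wrapper, and the WithPodListerFactory option are combined roughly as follows. This is a minimal sketch mirroring the pattern already used in provider_test.go and assuming the same imports; the pod name, IP, and port are illustrative values, not taken from this patch.

	// Build a fake lister that returns a single ready pod.
	fakeLister := &testingutil.FakePodLister{
		PodsList: []*corev1.Pod{
			testingutil.MakePod("pod-a").SetReady().SetPodIP("10.0.0.1").Obj(),
		},
	}
	// Inject the fake lister through the factory option instead of a real informer.
	datastore := NewK8sDataStore(WithPodListerFactory(
		func(pool *v1alpha1.InferencePool) *PodLister {
			return &PodLister{Lister: fakeLister}
		}))
	// setInferencePool is unexported, so this snippet lives in the backend package.
	datastore.setInferencePool(&v1alpha1.InferencePool{
		Spec: v1alpha1.InferencePoolSpec{TargetPortNumber: 8000},
	})
	provider := NewProvider(&FakePodMetricsClient{}, datastore)
	// refreshPodsOnce picks up pod-a and stores it under address 10.0.0.1:8000.
	provider.refreshPodsOnce()
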
diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 95ad49081..3dfe28f7c 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -12,6 +12,7 @@ import ( "log" "os" "path/filepath" + "strconv" "testing" "time" @@ -26,6 +27,8 @@ import ( "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" runserver "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/server" extprocutils "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/test" + testingutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" k8syaml "k8s.io/apimachinery/pkg/util/yaml" @@ -113,7 +116,7 @@ func SKIPTestHandleRequestBody(t *testing.T) { { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-1"), + RawValue: []byte("pod-1:8000"), }, }, { @@ -179,7 +182,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("address-1"), + RawValue: []byte("pod-1:8000"), }, }, { @@ -193,7 +196,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "address-1", + StringValue: "pod-1:8000", }, }, }, @@ -203,47 +206,38 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, } - pods := []*backend.PodMetrics{ + metrics := []*backend.Metrics{ { - Pod: extprocutils.FakePod(0), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, }, }, { - Pod: extprocutils.FakePod(1), - Metrics: backend.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, - }, + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, }, }, { - Pod: extprocutils.FakePod(2), - Metrics: backend.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - }, + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, }, }, } // Set up global k8sclient and extproc server runner with test environment config - BeforeSuit() + podMetrics := BeforeSuit(metrics) for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer(t, pods) + client, cleanup := setUpHermeticServer(t, podMetrics) t.Cleanup(cleanup) want := &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_RequestBody{ @@ -324,8 +318,8 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr } } } + inferencePool := &v1alpha1.InferencePool{} for _, doc := range docs { - inferencePool := &v1alpha1.InferencePool{} if err = yaml.Unmarshal(doc, inferencePool); err != nil { log.Fatalf("Can't unmarshal object: %v", doc) } @@ -334,18 +328,19 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr if err := k8sClient.Create(context.Background(), inferencePool); err != nil { log.Fatalf("unable to create inferencePool %v: %v", inferencePool.Name, err) } + // expecting 
a single inferencepool + break } } ps := make(backend.PodSet) - pms := make(map[backend.Pod]*backend.PodMetrics) + pms := make(map[string]*backend.PodMetrics) for _, pod := range pods { ps[pod.Pod] = true - pms[pod.Pod] = pod + pms[pod.Pod.Name] = pod } pmc := &backend.FakePodMetricsClient{Res: pms} - - server := serverRunner.Start(backend.NewK8sDataStore(backend.WithPods(pods)), pmc) + server := serverRunner.Start(pmc) if err != nil { log.Fatalf("Ext-proc failed with the err: %v", err) } @@ -373,7 +368,7 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr } // Sets up a test environment and returns the runner struct -func BeforeSuit() { +func BeforeSuit(metrics []*backend.Metrics) []*backend.PodMetrics { // Set up mock k8s API Client testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, @@ -395,12 +390,35 @@ func BeforeSuit() { log.Fatalf("No error, but returned kubernetes client is nil, cfg: %v", cfg) } + podMetrics := []*backend.PodMetrics{} + fakeLister := &testingutil.FakePodLister{ + PodsList: []*corev1.Pod{}, + } + for i, m := range metrics { + podName := "pod-" + strconv.Itoa(i) + pod := testingutil.MakePod(podName).SetReady().SetPodIP(podName).Obj() + fakeLister.PodsList = append(fakeLister.PodsList, pod) + podMetrics = append(podMetrics, &backend.PodMetrics{ + Pod: backend.Pod{ + Name: pod.Name, + Address: pod.Status.PodIP + ":8000", + }, + Metrics: *m, + }) + } + serverRunner = runserver.NewDefaultExtProcServerRunner() // Adjust from defaults serverRunner.PoolName = "vllm-llama2-7b-pool" serverRunner.Scheme = scheme serverRunner.Config = cfg - serverRunner.Datastore = backend.NewK8sDataStore() + serverRunner.Datastore = backend.NewK8sDataStore(backend.WithPodListerFactory( + func(pool *v1alpha1.InferencePool) *backend.PodLister { + klog.V(1).Infof("Setting the fake lister %v", len(fakeLister.PodsList)) + return &backend.PodLister{ + Lister: fakeLister, + } + })) serverRunner.Setup() @@ -408,6 +426,10 @@ func BeforeSuit() { go func() { serverRunner.StartManager() }() + + // Wait the reconcilers to populate the datastore. + time.Sleep(5 * time.Second) + return podMetrics } func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) {
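
Appendix note (not part of the patch): the revision-based cleanup introduced in refreshPodsOnce boils down to the following standalone toy program. All names here are illustrative, and the snippet only sketches the idea of tagging live entries with a fresh revision and dropping the rest.

	package main

	import (
		"fmt"
		"math/rand"
		"sync"
	)

	type entry struct {
		name     string
		revision int
	}

	// refresh tags every live name with a fresh revision and then drops entries
	// whose revision was not updated, i.e. names that are no longer listed.
	func refresh(store *sync.Map, live []string) {
		rev := rand.Int() // only needs to differ from the previous refresh
		for _, name := range live {
			if v, ok := store.Load(name); ok {
				v.(*entry).revision = rev
				continue
			}
			store.Store(name, &entry{name: name, revision: rev})
		}
		store.Range(func(k, v any) bool {
			if v.(*entry).revision != rev {
				store.Delete(k)
			}
			return true
		})
	}

	func main() {
		store := &sync.Map{}
		refresh(store, []string{"pod-1", "pod-2"})
		refresh(store, []string{"pod-2", "pod-3"}) // pod-1 is dropped here
		store.Range(func(k, _ any) bool {
			fmt.Println(k)
			return true
		})
	}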