diff --git a/.golangci.yml b/.golangci.yml index 2ad3b93da..1462bcc77 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -14,6 +14,7 @@ linters: - dupword - durationcheck - fatcontext + - gci - ginkgolinter - gocritic - govet diff --git a/pkg/ext-proc/backend/datastore.go b/pkg/ext-proc/backend/datastore.go index 627ddbe52..b466a2ed5 100644 --- a/pkg/ext-proc/backend/datastore.go +++ b/pkg/ext-proc/backend/datastore.go @@ -1,26 +1,13 @@ package backend import ( - "context" "errors" "math/rand" "sync" - "time" - "github.com/google/go-cmp/cmp" "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" corev1 "k8s.io/api/core/v1" - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/meta" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/labels" - "k8s.io/client-go/informers" - informersv1 "k8s.io/client-go/informers/core/v1" - "k8s.io/client-go/kubernetes" - clientset "k8s.io/client-go/kubernetes" - listersv1 "k8s.io/client-go/listers/core/v1" - "k8s.io/client-go/tools/cache" "k8s.io/klog/v2" ) @@ -28,9 +15,8 @@ func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore { store := &K8sDatastore{ poolMu: sync.RWMutex{}, InferenceModels: &sync.Map{}, + pods: &sync.Map{}, } - - store.podListerFactory = store.createPodLister for _, opt := range options { opt(store) } @@ -39,68 +25,29 @@ func NewK8sDataStore(options ...K8sDatastoreOption) *K8sDatastore { // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) type K8sDatastore struct { - client kubernetes.Interface // poolMu is used to synchronize access to the inferencePool. - poolMu sync.RWMutex - inferencePool *v1alpha1.InferencePool - podListerFactory PodListerFactory - podLister *PodLister - InferenceModels *sync.Map + poolMu sync.RWMutex + inferencePool *v1alpha1.InferencePool + InferenceModels *sync.Map + pods *sync.Map } type K8sDatastoreOption func(*K8sDatastore) -type PodListerFactory func(*v1alpha1.InferencePool) *PodLister // WithPods can be used in tests to override the pods. -func WithPodListerFactory(factory PodListerFactory) K8sDatastoreOption { +func WithPods(pods []*PodMetrics) K8sDatastoreOption { return func(store *K8sDatastore) { - store.podListerFactory = factory + store.pods = &sync.Map{} + for _, pod := range pods { + store.pods.Store(pod.Pod, true) + } } } -type PodLister struct { - Lister listersv1.PodLister - sharedInformer informers.SharedInformerFactory -} - -func (l *PodLister) listEverything() ([]*corev1.Pod, error) { - return l.Lister.List(labels.Everything()) - -} - -func (ds *K8sDatastore) SetClient(client kubernetes.Interface) { - ds.client = client -} - func (ds *K8sDatastore) setInferencePool(pool *v1alpha1.InferencePool) { ds.poolMu.Lock() defer ds.poolMu.Unlock() - - if ds.inferencePool != nil && cmp.Equal(ds.inferencePool.Spec.Selector, pool.Spec.Selector) { - // Pool updated, but the selector stayed the same, so no need to change the informer. - ds.inferencePool = pool - return - } - - // New pool or selector updated. ds.inferencePool = pool - - if ds.podLister != nil && ds.podLister.sharedInformer != nil { - // Shutdown the old informer async since this takes a few seconds. - go func() { - ds.podLister.sharedInformer.Shutdown() - }() - } - - if ds.podListerFactory != nil { - // Create a new informer with the new selector. 
- ds.podLister = ds.podListerFactory(ds.inferencePool) - if ds.podLister != nil && ds.podLister.sharedInformer != nil { - ctx := context.Background() - ds.podLister.sharedInformer.Start(ctx.Done()) - ds.podLister.sharedInformer.WaitForCacheSync(ctx.Done()) - } - } } func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) { @@ -112,58 +59,13 @@ func (ds *K8sDatastore) getInferencePool() (*v1alpha1.InferencePool, error) { return ds.inferencePool, nil } -func (ds *K8sDatastore) createPodLister(pool *v1alpha1.InferencePool) *PodLister { - if ds.client == nil { - return nil - } - klog.V(logutil.DEFAULT).Infof("Creating informer for pool %v", pool.Name) - selectorSet := make(map[string]string) - for k, v := range pool.Spec.Selector { - selectorSet[string(k)] = string(v) - } - - newPodInformer := func(cs clientset.Interface, resyncPeriod time.Duration) cache.SharedIndexInformer { - informer := informersv1.NewFilteredPodInformer(cs, pool.Namespace, resyncPeriod, cache.Indexers{}, func(options *metav1.ListOptions) { - options.LabelSelector = labels.SelectorFromSet(selectorSet).String() - }) - err := informer.SetTransform(func(obj interface{}) (interface{}, error) { - // Remove unnecessary fields to improve memory footprint. - if accessor, err := meta.Accessor(obj); err == nil { - if accessor.GetManagedFields() != nil { - accessor.SetManagedFields(nil) - } - } - return obj, nil - }) - if err != nil { - klog.Errorf("Failed to set pod transformer: %v", err) - } - return informer - } - // 0 means we disable resyncing, it is not really useful to resync every hour (the controller-runtime default), - // if things go wrong in the watch, no one will wait for an hour for things to get fixed. - // As precedence, kube-scheduler also disables this since it is expensive to list all pods from the api-server regularly. 
- resyncPeriod := time.Duration(0) - sharedInformer := informers.NewSharedInformerFactory(ds.client, resyncPeriod) - sharedInformer.InformerFor(&v1.Pod{}, newPodInformer) - - return &PodLister{ - Lister: sharedInformer.Core().V1().Pods().Lister(), - sharedInformer: sharedInformer, - } -} - -func (ds *K8sDatastore) getPods() ([]*corev1.Pod, error) { - ds.poolMu.RLock() - defer ds.poolMu.RUnlock() - if !ds.HasSynced() { - return nil, errors.New("InferencePool is not initialized in datastore") - } - pods, err := ds.podLister.listEverything() - if err != nil { - return nil, err - } - return pods, nil +func (ds *K8sDatastore) GetPodIPs() []string { + var ips []string + ds.pods.Range(func(name, pod any) bool { + ips = append(ips, pod.(*corev1.Pod).Status.PodIP) + return true + }) + return ips } func (s *K8sDatastore) FetchModelData(modelName string) (returnModel *v1alpha1.InferenceModel) { diff --git a/pkg/ext-proc/backend/endpointslice_reconciler.go b/pkg/ext-proc/backend/endpointslice_reconciler.go new file mode 100644 index 000000000..a2a9790f2 --- /dev/null +++ b/pkg/ext-proc/backend/endpointslice_reconciler.go @@ -0,0 +1,109 @@ +package backend + +import ( + "context" + "strconv" + "time" + + "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" + logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" + discoveryv1 "k8s.io/api/discovery/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/record" + klog "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/predicate" +) + +var ( + serviceOwnerLabel = "kubernetes.io/service-name" +) + +type EndpointSliceReconciler struct { + client.Client + Scheme *runtime.Scheme + Record record.EventRecorder + ServiceName string + Zone string + Datastore *K8sDatastore +} + +func (c *EndpointSliceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + inferencePool, err := c.Datastore.getInferencePool() + if err != nil { + klog.V(logutil.DEFAULT).Infof("Skipping reconciling EndpointSlice because the InferencePool is not available yet: %v", err) + return ctrl.Result{Requeue: true, RequeueAfter: time.Second}, nil + } + + klog.V(logutil.DEFAULT).Info("Reconciling EndpointSlice ", req.NamespacedName) + + endpointSlice := &discoveryv1.EndpointSlice{} + if err := c.Get(ctx, req.NamespacedName, endpointSlice); err != nil { + klog.Errorf("Unable to get EndpointSlice: %v", err) + return ctrl.Result{}, err + } + c.updateDatastore(endpointSlice, inferencePool) + + return ctrl.Result{}, nil +} + +// TODO: Support multiple endpointslices for a single service +func (c *EndpointSliceReconciler) updateDatastore( + slice *discoveryv1.EndpointSlice, + inferencePool *v1alpha1.InferencePool) { + podMap := make(map[Pod]bool) + + for _, endpoint := range slice.Endpoints { + klog.V(logutil.DEFAULT).Infof("Zone: %v \n endpoint: %+v \n", c.Zone, endpoint) + if c.validPod(endpoint) { + pod := Pod{ + Name: endpoint.TargetRef.Name, + Address: endpoint.Addresses[0] + ":" + strconv.Itoa(int(inferencePool.Spec.TargetPortNumber)), + } + podMap[pod] = true + klog.V(logutil.DEFAULT).Infof("Storing pod %v", pod) + c.Datastore.pods.Store(pod, true) + } + } + + removeOldPods := func(k, v any) bool { + pod, ok := k.(Pod) + if !ok { + klog.Errorf("Unable to cast key to Pod: %v", k) + return false + } + if _, ok := podMap[pod]; !ok { + 
klog.V(logutil.DEFAULT).Infof("Removing pod %v", pod) + c.Datastore.pods.Delete(pod) + } + return true + } + c.Datastore.pods.Range(removeOldPods) +} + +func (c *EndpointSliceReconciler) SetupWithManager(mgr ctrl.Manager) error { + ownsEndPointSlice := func(object client.Object) bool { + // Check if the object is an EndpointSlice + endpointSlice, ok := object.(*discoveryv1.EndpointSlice) + if !ok { + return false + } + + gotLabel := endpointSlice.ObjectMeta.Labels[serviceOwnerLabel] + wantLabel := c.ServiceName + return gotLabel == wantLabel + } + + return ctrl.NewControllerManagedBy(mgr). + For(&discoveryv1.EndpointSlice{}, + builder.WithPredicates(predicate.NewPredicateFuncs(ownsEndPointSlice))). + Complete(c) +} + +func (c *EndpointSliceReconciler) validPod(endpoint discoveryv1.Endpoint) bool { + validZone := c.Zone == "" || c.Zone != "" && *endpoint.Zone == c.Zone + return validZone && *endpoint.Conditions.Ready + +} diff --git a/pkg/ext-proc/backend/endpointslice_reconcilier_test.go b/pkg/ext-proc/backend/endpointslice_reconcilier_test.go new file mode 100644 index 000000000..e3c927ba8 --- /dev/null +++ b/pkg/ext-proc/backend/endpointslice_reconcilier_test.go @@ -0,0 +1,202 @@ +package backend + +import ( + "sync" + "testing" + + "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" + v1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" +) + +var ( + basePod1 = Pod{Name: "pod1"} + basePod2 = Pod{Name: "pod2"} + basePod3 = Pod{Name: "pod3"} +) + +func TestUpdateDatastore_EndpointSliceReconciler(t *testing.T) { + tests := []struct { + name string + datastore *K8sDatastore + incomingSlice *discoveryv1.EndpointSlice + wantPods *sync.Map + }{ + { + name: "Add new pod", + datastore: &K8sDatastore{ + pods: populateMap(basePod1, basePod2), + inferencePool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + }, + }, + }, + incomingSlice: &discoveryv1.EndpointSlice{ + Endpoints: []discoveryv1.Endpoint{ + { + TargetRef: &v1.ObjectReference{ + Name: "pod1", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod2", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod3", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + }, + }, + wantPods: populateMap(basePod1, basePod2, basePod3), + }, + { + name: "New pod, but its not ready yet. 
Do not add.", + datastore: &K8sDatastore{ + pods: populateMap(basePod1, basePod2), + inferencePool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + }, + }, + }, + incomingSlice: &discoveryv1.EndpointSlice{ + Endpoints: []discoveryv1.Endpoint{ + { + TargetRef: &v1.ObjectReference{ + Name: "pod1", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod2", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod3", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: new(bool), + }, + Addresses: []string{"0.0.0.0"}, + }, + }, + }, + wantPods: populateMap(basePod1, basePod2), + }, + { + name: "Existing pod not ready, new pod added, and is ready", + datastore: &K8sDatastore{ + pods: populateMap(basePod1, basePod2), + inferencePool: &v1alpha1.InferencePool{ + Spec: v1alpha1.InferencePoolSpec{ + TargetPortNumber: int32(8000), + }, + }, + }, + incomingSlice: &discoveryv1.EndpointSlice{ + Endpoints: []discoveryv1.Endpoint{ + { + TargetRef: &v1.ObjectReference{ + Name: "pod1", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: new(bool), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod2", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + { + TargetRef: &v1.ObjectReference{ + Name: "pod3", + }, + Zone: new(string), + Conditions: discoveryv1.EndpointConditions{ + Ready: truePointer(), + }, + Addresses: []string{"0.0.0.0"}, + }, + }, + }, + wantPods: populateMap(basePod3, basePod2), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + endpointSliceReconciler := &EndpointSliceReconciler{Datastore: test.datastore, Zone: ""} + endpointSliceReconciler.updateDatastore(test.incomingSlice, test.datastore.inferencePool) + + if mapsEqual(endpointSliceReconciler.Datastore.pods, test.wantPods) { + t.Errorf("Unexpected output pod mismatch. 
\n Got %v \n Want: %v \n", + endpointSliceReconciler.Datastore.pods, + test.wantPods) + } + }) + } +} + +func mapsEqual(map1, map2 *sync.Map) bool { + equal := true + + map1.Range(func(k, v any) bool { + if _, ok := map2.Load(k); !ok { + equal = false + return false + } + return true + }) + map2.Range(func(k, v any) bool { + if _, ok := map1.Load(k); !ok { + equal = false + return false + } + return true + }) + + return equal +} + +func truePointer() *bool { + primitivePointersAreSilly := true + return &primitivePointersAreSilly +} diff --git a/pkg/ext-proc/backend/fake.go b/pkg/ext-proc/backend/fake.go index 63f20db60..c45454975 100644 --- a/pkg/ext-proc/backend/fake.go +++ b/pkg/ext-proc/backend/fake.go @@ -8,16 +8,16 @@ import ( ) type FakePodMetricsClient struct { - Err map[string]error - Res map[string]*PodMetrics + Err map[Pod]error + Res map[Pod]*PodMetrics } func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod Pod, existing *PodMetrics) (*PodMetrics, error) { - if err, ok := f.Err[pod.Name]; ok { + if err, ok := f.Err[pod]; ok { return nil, err } - klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod.Name]) - return f.Res[pod.Name], nil + klog.V(1).Infof("pod: %+v\n existing: %+v \n new: %+v \n", pod, existing, f.Res[pod]) + return f.Res[pod], nil } type FakeDataStore struct { diff --git a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go index 117766b9c..5609ca532 100644 --- a/pkg/ext-proc/backend/inferencemodel_reconciler_test.go +++ b/pkg/ext-proc/backend/inferencemodel_reconciler_test.go @@ -146,24 +146,3 @@ func populateServiceMap(services ...*v1alpha1.InferenceModel) *sync.Map { } return returnVal } - -func mapsEqual(map1, map2 *sync.Map) bool { - equal := true - - map1.Range(func(k, v any) bool { - if _, ok := map2.Load(k); !ok { - equal = false - return false - } - return true - }) - map2.Range(func(k, v any) bool { - if _, ok := map1.Load(k); !ok { - equal = false - return false - } - return true - }) - - return equal -} diff --git a/pkg/ext-proc/backend/inferencepool_reconciler.go b/pkg/ext-proc/backend/inferencepool_reconciler.go index 0c2ae75f5..35a41f8ff 100644 --- a/pkg/ext-proc/backend/inferencepool_reconciler.go +++ b/pkg/ext-proc/backend/inferencepool_reconciler.go @@ -21,6 +21,7 @@ type InferencePoolReconciler struct { Record record.EventRecorder PoolNamespacedName types.NamespacedName Datastore *K8sDatastore + Zone string } func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { diff --git a/pkg/ext-proc/backend/provider.go b/pkg/ext-proc/backend/provider.go index d6ccf85fa..8bf672579 100644 --- a/pkg/ext-proc/backend/provider.go +++ b/pkg/ext-proc/backend/provider.go @@ -3,14 +3,11 @@ package backend import ( "context" "fmt" - "math/rand" - "strconv" "sync" "time" "go.uber.org/multierr" logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" - corev1 "k8s.io/api/core/v1" klog "k8s.io/klog/v2" ) @@ -29,8 +26,7 @@ func NewProvider(pmc PodMetricsClient, datastore *K8sDatastore) *Provider { // Provider provides backend pods and information such as metrics. 
type Provider struct { - // key: PodName, value: *PodMetrics - // TODO: change to use NamespacedName once we support multi-tenant inferencePools + // key: Pod, value: *PodMetrics podMetrics sync.Map pmc PodMetricsClient datastore *K8sDatastore @@ -51,11 +47,11 @@ func (p *Provider) AllPodMetrics() []*PodMetrics { } func (p *Provider) UpdatePodMetrics(pod Pod, pm *PodMetrics) { - p.podMetrics.Store(pod.Name, pm) + p.podMetrics.Store(pod, pm) } func (p *Provider) GetPodMetrics(pod Pod) (*PodMetrics, bool) { - val, ok := p.podMetrics.Load(pod.Name) + val, ok := p.podMetrics.Load(pod) if ok { return val.(*PodMetrics), true } @@ -105,70 +101,31 @@ func (p *Provider) Init(refreshPodsInterval, refreshMetricsInterval time.Duratio // refreshPodsOnce lists pods and updates keys in the podMetrics map. // Note this function doesn't update the PodMetrics value, it's done separately. func (p *Provider) refreshPodsOnce() { - pods, err := p.datastore.getPods() - if err != nil { - klog.V(logutil.DEFAULT).Infof("Couldn't list pods: %v", err) - p.podMetrics.Clear() - return - } - pool, _ := p.datastore.getInferencePool() - // revision is used to track which entries we need to remove in the next iteration that removes - // metrics for pods that don't exist anymore. Otherwise we have to build a map of the listed pods, - // which is not efficient. Revision can be any random id as long as it is different from the last - // refresh, so it should be very reliable (as reliable as the probability of randomly picking two - // different numbers from range 0 - maxInt). - revision := rand.Int() - ready := 0 - for _, pod := range pods { - if !podIsReady(pod) { - continue - } - // a ready pod - ready++ - if val, ok := p.podMetrics.Load(pod.Name); ok { - // pod already exists - pm := val.(*PodMetrics) - pm.revision = revision - continue - } - // new pod, add to the store for probing - new := &PodMetrics{ - Pod: Pod{ - Name: pod.Name, - Address: pod.Status.PodIP + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)), - }, - Metrics: Metrics{ - ActiveModels: make(map[string]int), - }, - revision: revision, + // merge new pods with cached ones. + // add new pod to the map + addNewPods := func(k, v any) bool { + pod := k.(Pod) + if _, ok := p.podMetrics.Load(pod); !ok { + new := &PodMetrics{ + Pod: pod, + Metrics: Metrics{ + ActiveModels: make(map[string]int), + }, + } + p.podMetrics.Store(pod, new) } - p.podMetrics.Store(pod.Name, new) + return true } - - klog.V(logutil.DEFAULT).Infof("Pods in pool %s/%s with selector %v: total=%v ready=%v", - pool.Namespace, pool.Name, pool.Spec.Selector, len(pods), ready) - // remove pods that don't exist any more. 
mergeFn := func(k, v any) bool { - pm := v.(*PodMetrics) - if pm.revision != revision { - p.podMetrics.Delete(pm.Pod.Name) + pod := k.(Pod) + if _, ok := p.datastore.pods.Load(pod); !ok { + p.podMetrics.Delete(pod) } return true } p.podMetrics.Range(mergeFn) -} - -func podIsReady(pod *corev1.Pod) bool { - if pod.DeletionTimestamp != nil { - return false - } - for _, condition := range pod.Status.Conditions { - if condition.Type == corev1.PodReady { - return condition.Status == corev1.ConditionTrue - } - } - return false + p.datastore.pods.Range(addNewPods) } func (p *Provider) refreshMetricsOnce() error { @@ -184,8 +141,8 @@ func (p *Provider) refreshMetricsOnce() error { errCh := make(chan error) processOnePod := func(key, value any) bool { klog.V(logutil.TRACE).Infof("Processing pod %v and metric %v", key, value) + pod := key.(Pod) existing := value.(*PodMetrics) - pod := existing.Pod wg.Add(1) go func() { defer wg.Done() diff --git a/pkg/ext-proc/backend/provider_test.go b/pkg/ext-proc/backend/provider_test.go index 9159ba481..ad231f575 100644 --- a/pkg/ext-proc/backend/provider_test.go +++ b/pkg/ext-proc/backend/provider_test.go @@ -2,18 +2,17 @@ package backend import ( "errors" + "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1" - testingutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" - corev1 "k8s.io/api/core/v1" ) var ( pod1 = &PodMetrics{ - Pod: Pod{Name: "pod1", Address: "address1:9009"}, + Pod: Pod{Name: "pod1"}, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -25,7 +24,7 @@ var ( }, } pod2 = &PodMetrics{ - Pod: Pod{Name: "pod2", Address: "address2:9009"}, + Pod: Pod{Name: "pod2"}, Metrics: Metrics{ WaitingQueueSize: 1, KVCacheUsagePercent: 0.2, @@ -39,67 +38,44 @@ var ( ) func TestProvider(t *testing.T) { - allPodsLister := &testingutil.FakePodLister{ - PodsList: []*corev1.Pod{ - testingutil.MakePod(pod1.Pod.Name).SetReady().SetPodIP("address1").Obj(), - testingutil.MakePod(pod2.Pod.Name).SetReady().SetPodIP("address2").Obj(), - }, - } - allPodsMetricsClient := &FakePodMetricsClient{ - Res: map[string]*PodMetrics{ - pod1.Pod.Name: pod1, - pod2.Pod.Name: pod2, - }, - } - tests := []struct { - name string - initPodMetrics []*PodMetrics - lister *testingutil.FakePodLister - pmc PodMetricsClient - step func(*Provider) - want []*PodMetrics + name string + pmc PodMetricsClient + datastore *K8sDatastore + initErr bool + want []*PodMetrics }{ { - name: "Init without refreshing pods", - initPodMetrics: []*PodMetrics{pod1, pod2}, - lister: allPodsLister, - pmc: allPodsMetricsClient, - step: func(p *Provider) { - _ = p.refreshMetricsOnce() + name: "Init success", + datastore: &K8sDatastore{ + pods: populateMap(pod1.Pod, pod2.Pod), }, - want: []*PodMetrics{pod1, pod2}, - }, - { - name: "Fetching all success", - lister: allPodsLister, - pmc: allPodsMetricsClient, - step: func(p *Provider) { - p.refreshPodsOnce() - _ = p.refreshMetricsOnce() + pmc: &FakePodMetricsClient{ + Res: map[Pod]*PodMetrics{ + pod1.Pod: pod1, + pod2.Pod: pod2, + }, }, want: []*PodMetrics{pod1, pod2}, }, { - name: "Fetch metrics error", - lister: allPodsLister, + name: "Fetch metrics error", pmc: &FakePodMetricsClient{ - Err: map[string]error{ - pod2.Pod.Name: errors.New("injected error"), + Err: map[Pod]error{ + pod2.Pod: errors.New("injected error"), }, - Res: map[string]*PodMetrics{ - pod1.Pod.Name: pod1, + 
Res: map[Pod]*PodMetrics{ + pod1.Pod: pod1, }, }, - step: func(p *Provider) { - p.refreshPodsOnce() - _ = p.refreshMetricsOnce() + datastore: &K8sDatastore{ + pods: populateMap(pod1.Pod, pod2.Pod), }, want: []*PodMetrics{ pod1, // Failed to fetch pod2 metrics so it remains the default values. { - Pod: pod2.Pod, + Pod: Pod{Name: "pod2"}, Metrics: Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0, @@ -109,73 +85,30 @@ func TestProvider(t *testing.T) { }, }, }, - { - name: "A new pod added", - initPodMetrics: []*PodMetrics{pod2}, - lister: allPodsLister, - pmc: allPodsMetricsClient, - step: func(p *Provider) { - p.refreshPodsOnce() - _ = p.refreshMetricsOnce() - }, - want: []*PodMetrics{pod1, pod2}, - }, - { - name: "A pod removed", - initPodMetrics: []*PodMetrics{pod1, pod2}, - lister: &testingutil.FakePodLister{ - PodsList: []*corev1.Pod{ - testingutil.MakePod(pod2.Pod.Name).SetReady().SetPodIP("address2").Obj(), - }, - }, - pmc: allPodsMetricsClient, - step: func(p *Provider) { - p.refreshPodsOnce() - _ = p.refreshMetricsOnce() - }, - want: []*PodMetrics{pod2}, - }, - { - name: "A pod removed, another added", - initPodMetrics: []*PodMetrics{pod1}, - lister: &testingutil.FakePodLister{ - PodsList: []*corev1.Pod{ - testingutil.MakePod(pod1.Pod.Name).SetReady().SetPodIP("address1").Obj(), - }, - }, - pmc: allPodsMetricsClient, - step: func(p *Provider) { - p.refreshPodsOnce() - _ = p.refreshMetricsOnce() - }, - want: []*PodMetrics{pod1}, - }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - datastore := NewK8sDataStore(WithPodListerFactory( - func(pool *v1alpha1.InferencePool) *PodLister { - return &PodLister{ - Lister: test.lister, - } - })) - datastore.setInferencePool(&v1alpha1.InferencePool{ - Spec: v1alpha1.InferencePoolSpec{TargetPortNumber: 9009}, - }) - p := NewProvider(test.pmc, datastore) - for _, m := range test.initPodMetrics { - p.UpdatePodMetrics(m.Pod, m) + p := NewProvider(test.pmc, test.datastore) + err := p.Init(time.Millisecond, time.Millisecond) + if test.initErr != (err != nil) { + t.Fatalf("Unexpected error, got: %v, want: %v", err, test.initErr) } - test.step(p) metrics := p.AllPodMetrics() lessFunc := func(a, b *PodMetrics) bool { return a.String() < b.String() } - if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc), - cmpopts.IgnoreFields(PodMetrics{}, "revision")); diff != "" { + if diff := cmp.Diff(test.want, metrics, cmpopts.SortSlices(lessFunc)); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) } } + +func populateMap(pods ...Pod) *sync.Map { + newMap := &sync.Map{} + for _, pod := range pods { + newMap.Store(pod, true) + } + return newMap +} diff --git a/pkg/ext-proc/backend/types.go b/pkg/ext-proc/backend/types.go index d375e4ec2..7e399fedc 100644 --- a/pkg/ext-proc/backend/types.go +++ b/pkg/ext-proc/backend/types.go @@ -28,7 +28,6 @@ type Metrics struct { type PodMetrics struct { Pod Metrics - revision int } func (pm *PodMetrics) String() string { diff --git a/pkg/ext-proc/health.go b/pkg/ext-proc/health.go index 62527d06a..488851eb3 100644 --- a/pkg/ext-proc/health.go +++ b/pkg/ext-proc/health.go @@ -7,7 +7,6 @@ import ( healthPb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" - logutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/logging" klog "k8s.io/klog/v2" ) @@ -20,7 +19,7 @@ func (s *healthServer) Check(ctx context.Context, in 
*healthPb.HealthCheckRequest) (*healthPb.HealthCheckResponse, error) { klog.Infof("gRPC health check not serving: %s", in.String()) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_NOT_SERVING}, nil } - klog.V(logutil.DEBUG).Infof("gRPC health check serving: %s", in.String()) + klog.Infof("gRPC health check serving: %s", in.String()) return &healthPb.HealthCheckResponse{Status: healthPb.HealthCheckResponse_SERVING}, nil } diff --git a/pkg/ext-proc/main.go b/pkg/ext-proc/main.go index 98b7e6cad..a783aa2c5 100644 --- a/pkg/ext-proc/main.go +++ b/pkg/ext-proc/main.go @@ -18,7 +18,6 @@ import ( runserver "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/server" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" - "k8s.io/client-go/kubernetes" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" "k8s.io/component-base/metrics/legacyregistry" @@ -54,6 +53,14 @@ var ( "poolNamespace", runserver.DefaultPoolNamespace, "Namespace of the InferencePool this Endpoint Picker is associated with.") + serviceName = flag.String( + "serviceName", + runserver.DefaultServiceName, + "Name of the Service that will be used to read EndpointSlices from") + zone = flag.String( + "zone", + runserver.DefaultZone, + "The zone that this instance is created in. Will be passed to the corresponding endpointSlice. ") refreshPodsInterval = flag.Duration( "refreshPodsInterval", runserver.DefaultRefreshPodsInterval, @@ -99,6 +106,8 @@ func main() { TargetEndpointKey: *targetEndpointKey, PoolName: *poolName, PoolNamespace: *poolNamespace, + ServiceName: *serviceName, + Zone: *zone, RefreshPodsInterval: *refreshPodsInterval, RefreshMetricsInterval: *refreshMetricsInterval, Scheme: scheme, @@ -107,15 +116,12 @@ func main() { } serverRunner.Setup() - k8sClient, err := kubernetes.NewForConfigAndClient(cfg, serverRunner.Manager.GetHTTPClient()) - if err != nil { - klog.Fatalf("Failed to create client: %v", err) - } - datastore.SetClient(k8sClient) - // Start health and ext-proc servers in goroutines healthSvr := startHealthServer(datastore, *grpcHealthPort) - extProcSvr := serverRunner.Start(&vllm.PodMetricsClientImpl{}) + extProcSvr := serverRunner.Start( + datastore, + &vllm.PodMetricsClientImpl{}, + ) // Start metrics handler metricsSvr := startMetricsHandler(*metricsPort, cfg) @@ -210,5 +216,9 @@ func validateFlags() error { return fmt.Errorf("required %q flag not set", "poolName") } + if *serviceName == "" { + return fmt.Errorf("required %q flag not set", "serviceName") + } + return nil } diff --git a/pkg/ext-proc/scheduling/filter_test.go b/pkg/ext-proc/scheduling/filter_test.go index 34731d152..d88f437c7 100644 --- a/pkg/ext-proc/scheduling/filter_test.go +++ b/pkg/ext-proc/scheduling/filter_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" ) @@ -207,7 +206,7 @@ func TestFilter(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got, cmpopts.IgnoreFields(backend.PodMetrics{}, "revision")); diff != "" { + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) @@ -401,7 +400,7 @@ func TestFilterFunc(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.output, got, cmpopts.IgnoreFields(backend.PodMetrics{}, "revision")); diff != ""
{ + if diff := cmp.Diff(test.output, got); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } }) diff --git a/pkg/ext-proc/server/runserver.go b/pkg/ext-proc/server/runserver.go index 981dab114..1c9c1b2e2 100644 --- a/pkg/ext-proc/server/runserver.go +++ b/pkg/ext-proc/server/runserver.go @@ -23,12 +23,14 @@ type ExtProcServerRunner struct { TargetEndpointKey string PoolName string PoolNamespace string + ServiceName string + Zone string RefreshPodsInterval time.Duration RefreshMetricsInterval time.Duration Scheme *runtime.Scheme Config *rest.Config Datastore *backend.K8sDatastore - Manager ctrl.Manager + manager ctrl.Manager } // Default values for CLI flags in main @@ -37,6 +39,8 @@ const ( DefaultTargetEndpointKey = "x-gateway-destination-endpoint" // default for --targetEndpointKey DefaultPoolName = "" // required but no default DefaultPoolNamespace = "default" // default for --poolNamespace + DefaultServiceName = "" // required but no default + DefaultZone = "" // default for --zone DefaultRefreshPodsInterval = 10 * time.Second // default for --refreshPodsInterval DefaultRefreshMetricsInterval = 50 * time.Millisecond // default for --refreshMetricsInterval ) @@ -47,20 +51,22 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { TargetEndpointKey: DefaultTargetEndpointKey, PoolName: DefaultPoolName, PoolNamespace: DefaultPoolNamespace, + ServiceName: DefaultServiceName, + Zone: DefaultZone, RefreshPodsInterval: DefaultRefreshPodsInterval, RefreshMetricsInterval: DefaultRefreshMetricsInterval, // Scheme, Config, and Datastore can be assigned later. } } -// Setup creates the reconcilers for pools and models and starts the manager. +// Setup creates the reconcilers for pools, models, and endpointSlices and starts the manager. func (r *ExtProcServerRunner) Setup() { // Create a new manager to manage controllers mgr, err := ctrl.NewManager(r.Config, ctrl.Options{Scheme: r.Scheme}) if err != nil { klog.Fatalf("Failed to create controller manager: %v", err) } - r.Manager = mgr + r.manager = mgr // Create the controllers and register them with the manager if err := (&backend.InferencePoolReconciler{ @@ -88,10 +94,22 @@ func (r *ExtProcServerRunner) Setup() { }).SetupWithManager(mgr); err != nil { klog.Fatalf("Failed setting up InferenceModelReconciler: %v", err) } + + if err := (&backend.EndpointSliceReconciler{ + Datastore: r.Datastore, + Scheme: mgr.GetScheme(), + Client: mgr.GetClient(), + Record: mgr.GetEventRecorderFor("endpointslice"), + ServiceName: r.ServiceName, + Zone: r.Zone, + }).SetupWithManager(mgr); err != nil { + klog.Fatalf("Failed setting up EndpointSliceReconciler: %v", err) + } } // Start starts the Envoy external processor server in a goroutine. 
func (r *ExtProcServerRunner) Start( + podDatastore *backend.K8sDatastore, podMetricsClient backend.PodMetricsClient, ) *grpc.Server { svr := grpc.NewServer() @@ -104,7 +122,7 @@ func (r *ExtProcServerRunner) Start( klog.Infof("Ext-proc server listening on port: %d", r.GrpcPort) // Initialize backend provider - pp := backend.NewProvider(podMetricsClient, r.Datastore) + pp := backend.NewProvider(podMetricsClient, podDatastore) if err := pp.Init(r.RefreshPodsInterval, r.RefreshMetricsInterval); err != nil { klog.Fatalf("Failed to initialize backend provider: %v", err) } @@ -125,12 +143,13 @@ func (r *ExtProcServerRunner) Start( } func (r *ExtProcServerRunner) StartManager() { - if r.Manager == nil { + if r.manager == nil { klog.Fatalf("Runner has no manager setup to run: %v", r) } // Start the controller manager. Blocking and will return when shutdown is complete. klog.Infof("Starting controller manager") - if err := r.Manager.Start(ctrl.SetupSignalHandler()); err != nil { + mgr := r.manager + if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil { klog.Fatalf("Error starting controller manager: %v", err) } klog.Info("Controller manager shutting down") diff --git a/pkg/ext-proc/test/utils.go b/pkg/ext-proc/test/utils.go index a9dc4efa0..63972849e 100644 --- a/pkg/ext-proc/test/utils.go +++ b/pkg/ext-proc/test/utils.go @@ -18,13 +18,13 @@ import ( func StartExtProc(port int, refreshPodsInterval, refreshMetricsInterval time.Duration, pods []*backend.PodMetrics, models map[string]*v1alpha1.InferenceModel) *grpc.Server { ps := make(backend.PodSet) - pms := make(map[string]*backend.PodMetrics) + pms := make(map[backend.Pod]*backend.PodMetrics) for _, pod := range pods { ps[pod.Pod] = true - pms[pod.Pod.Name] = pod + pms[pod.Pod] = pod } pmc := &backend.FakePodMetricsClient{Res: pms} - pp := backend.NewProvider(pmc, backend.NewK8sDataStore()) + pp := backend.NewProvider(pmc, backend.NewK8sDataStore(backend.WithPods(pods))) if err := pp.Init(refreshPodsInterval, refreshMetricsInterval); err != nil { klog.Fatalf("failed to initialize: %v", err) } diff --git a/pkg/ext-proc/util/testing/lister.go b/pkg/ext-proc/util/testing/lister.go deleted file mode 100644 index 023f30a1d..000000000 --- a/pkg/ext-proc/util/testing/lister.go +++ /dev/null @@ -1,19 +0,0 @@ -package testing - -import ( - v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/labels" - listersv1 "k8s.io/client-go/listers/core/v1" -) - -type FakePodLister struct { - PodsList []*v1.Pod -} - -func (l *FakePodLister) List(selector labels.Selector) (ret []*v1.Pod, err error) { - return l.PodsList, nil -} - -func (l *FakePodLister) Pods(namespace string) listersv1.PodNamespaceLister { - panic("not implemented") -} diff --git a/pkg/ext-proc/util/testing/wrappers.go b/pkg/ext-proc/util/testing/wrappers.go deleted file mode 100644 index 7b593bbd9..000000000 --- a/pkg/ext-proc/util/testing/wrappers.go +++ /dev/null @@ -1,38 +0,0 @@ -package testing - -import ( - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) - -// PodWrapper wraps a Pod inside. -type PodWrapper struct{ corev1.Pod } - -// MakePod creates a Pod wrapper. -func MakePod(name string) *PodWrapper { - return &PodWrapper{ - corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: name, - }, - }, - } -} - -// Obj returns the inner Pod. 
-func (p *PodWrapper) Obj() *corev1.Pod { - return &p.Pod -} - -func (p *PodWrapper) SetReady() *PodWrapper { - p.Status.Conditions = []corev1.PodCondition{{ - Type: corev1.PodReady, - Status: corev1.ConditionTrue, - }} - return p -} - -func (p *PodWrapper) SetPodIP(podIP string) *PodWrapper { - p.Status.PodIP = podIP - return p -} diff --git a/pkg/manifests/ext_proc.yaml b/pkg/manifests/ext_proc.yaml index b9b860dc9..4e82779ef 100644 --- a/pkg/manifests/ext_proc.yaml +++ b/pkg/manifests/ext_proc.yaml @@ -77,6 +77,8 @@ spec: - "vllm-llama2-7b-pool" - -v - "3" + - -serviceName + - "vllm-llama2-7b-pool" - -grpcPort - "9002" - -grpcHealthPort diff --git a/pkg/manifests/vllm/deployment.yaml b/pkg/manifests/vllm/deployment.yaml index 1f5073e98..4af0891d7 100644 --- a/pkg/manifests/vllm/deployment.yaml +++ b/pkg/manifests/vllm/deployment.yaml @@ -1,3 +1,16 @@ +apiVersion: v1 +kind: Service +metadata: + name: vllm-llama2-7b-pool +spec: + selector: + app: vllm-llama2-7b-pool + ports: + - protocol: TCP + port: 8000 + targetPort: 8000 + type: ClusterIP +--- apiVersion: apps/v1 kind: Deployment metadata: diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index 019e858a2..c2c1ea928 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -245,6 +245,11 @@ func createModelServer(k8sClient client.Client, secretPath, deployPath string) { // Wait for the deployment to be available. testutils.DeploymentAvailable(ctx, k8sClient, deploy, modelReadyTimeout, interval) + + // Wait for the service to exist. + testutils.EventuallyExists(ctx, func() error { + return k8sClient.Get(ctx, types.NamespacedName{Namespace: nsName, Name: modelServerName}, &corev1.Service{}) + }, existsTimeout, interval) } // createEnvoy creates the envoy proxy resources used for testing from the given filePath. 
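Note: the changes above replace the informer-based pod lister with EndpointSlice-driven discovery. The EndpointSliceReconciler watches slices carrying the kubernetes.io/service-name label of the Service named by the new -serviceName flag and stores one backend.Pod (name plus endpoint-address:TargetPortNumber) per ready endpoint in the datastore; the provider then mirrors that map into its metrics cache. The sketch below shows how a test can now seed the datastore and provider without a Kubernetes client, following the pattern used in pkg/ext-proc/test/utils.go; the pod name and address are illustrative only, not taken from this change:

package backend_test

import (
	"testing"
	"time"

	"inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend"
)

func TestSeedDatastoreSketch(t *testing.T) {
	// Both the datastore and the fake metrics client are keyed by the
	// backend.Pod value itself now, not by pod name.
	pod := backend.Pod{Name: "pod1", Address: "10.0.0.1:8000"}
	pm := &backend.PodMetrics{
		Pod:     pod,
		Metrics: backend.Metrics{ActiveModels: map[string]int{}},
	}

	// WithPods seeds the pods map the same way the EndpointSlice reconciler would.
	ds := backend.NewK8sDataStore(backend.WithPods([]*backend.PodMetrics{pm}))
	pmc := &backend.FakePodMetricsClient{Res: map[backend.Pod]*backend.PodMetrics{pod: pm}}

	p := backend.NewProvider(pmc, ds)
	if err := p.Init(time.Millisecond, time.Millisecond); err != nil {
		t.Fatalf("init: %v", err)
	}
	// After Init, the provider's cache mirrors the datastore's pods.
	_ = p.AllPodMetrics()
}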
diff --git a/test/integration/hermetic_test.go b/test/integration/hermetic_test.go index 3dfe28f7c..95ad49081 100644 --- a/test/integration/hermetic_test.go +++ b/test/integration/hermetic_test.go @@ -12,7 +12,6 @@ import ( "log" "os" "path/filepath" - "strconv" "testing" "time" @@ -27,8 +26,6 @@ import ( "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/backend" runserver "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/server" extprocutils "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/test" - testingutil "inference.networking.x-k8s.io/gateway-api-inference-extension/pkg/ext-proc/util/testing" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" k8syaml "k8s.io/apimachinery/pkg/util/yaml" @@ -116,7 +113,7 @@ func SKIPTestHandleRequestBody(t *testing.T) { { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("pod-1:8000"), + RawValue: []byte("address-1"), }, }, { @@ -182,7 +179,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { { Header: &configPb.HeaderValue{ Key: runserver.DefaultTargetEndpointKey, - RawValue: []byte("pod-1:8000"), + RawValue: []byte("address-1"), }, }, { @@ -196,7 +193,7 @@ func TestKubeInferenceModelRequest(t *testing.T) { Fields: map[string]*structpb.Value{ runserver.DefaultTargetEndpointKey: { Kind: &structpb.Value_StringValue{ - StringValue: "pod-1:8000", + StringValue: "address-1", }, }, }, @@ -206,38 +203,47 @@ func TestKubeInferenceModelRequest(t *testing.T) { }, } - metrics := []*backend.Metrics{ + pods := []*backend.PodMetrics{ { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + Pod: extprocutils.FakePod(0), + Metrics: backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, }, }, { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, + Pod: extprocutils.FakePod(1), + Metrics: backend.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.1, + ActiveModels: map[string]int{ + "foo": 1, + "sql-lora-1fdg2": 1, + }, }, }, { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, + Pod: extprocutils.FakePod(2), + Metrics: backend.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.2, + ActiveModels: map[string]int{ + "foo": 1, + }, }, }, } // Set up global k8sclient and extproc server runner with test environment config - podMetrics := BeforeSuit(metrics) + BeforeSuit() for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer(t, podMetrics) + client, cleanup := setUpHermeticServer(t, pods) t.Cleanup(cleanup) want := &extProcPb.ProcessingResponse{ Response: &extProcPb.ProcessingResponse_RequestBody{ @@ -318,8 +324,8 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr } } } - inferencePool := &v1alpha1.InferencePool{} for _, doc := range docs { + inferencePool := &v1alpha1.InferencePool{} if err = yaml.Unmarshal(doc, inferencePool); err != nil { log.Fatalf("Can't unmarshal object: %v", doc) } @@ -328,19 +334,18 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr if err := k8sClient.Create(context.Background(), inferencePool); err != nil { log.Fatalf("unable to create inferencePool %v: %v", inferencePool.Name, err) } - // expecting 
a single inferencepool - break } } ps := make(backend.PodSet) - pms := make(map[string]*backend.PodMetrics) + pms := make(map[backend.Pod]*backend.PodMetrics) for _, pod := range pods { ps[pod.Pod] = true - pms[pod.Pod.Name] = pod + pms[pod.Pod] = pod } pmc := &backend.FakePodMetricsClient{Res: pms} - server := serverRunner.Start(pmc) + + server := serverRunner.Start(backend.NewK8sDataStore(backend.WithPods(pods)), pmc) if err != nil { log.Fatalf("Ext-proc failed with the err: %v", err) } @@ -368,7 +373,7 @@ func setUpHermeticServer(t *testing.T, pods []*backend.PodMetrics) (client extPr } // Sets up a test environment and returns the runner struct -func BeforeSuit(metrics []*backend.Metrics) []*backend.PodMetrics { +func BeforeSuit() { // Set up mock k8s API Client testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, @@ -390,35 +395,12 @@ func BeforeSuit(metrics []*backend.Metrics) []*backend.PodMetrics { log.Fatalf("No error, but returned kubernetes client is nil, cfg: %v", cfg) } - podMetrics := []*backend.PodMetrics{} - fakeLister := &testingutil.FakePodLister{ - PodsList: []*corev1.Pod{}, - } - for i, m := range metrics { - podName := "pod-" + strconv.Itoa(i) - pod := testingutil.MakePod(podName).SetReady().SetPodIP(podName).Obj() - fakeLister.PodsList = append(fakeLister.PodsList, pod) - podMetrics = append(podMetrics, &backend.PodMetrics{ - Pod: backend.Pod{ - Name: pod.Name, - Address: pod.Status.PodIP + ":8000", - }, - Metrics: *m, - }) - } - serverRunner = runserver.NewDefaultExtProcServerRunner() // Adjust from defaults serverRunner.PoolName = "vllm-llama2-7b-pool" serverRunner.Scheme = scheme serverRunner.Config = cfg - serverRunner.Datastore = backend.NewK8sDataStore(backend.WithPodListerFactory( - func(pool *v1alpha1.InferencePool) *backend.PodLister { - klog.V(1).Infof("Setting the fake lister %v", len(fakeLister.PodsList)) - return &backend.PodLister{ - Lister: fakeLister, - } - })) + serverRunner.Datastore = backend.NewK8sDataStore() serverRunner.Setup() @@ -426,10 +408,6 @@ func BeforeSuit(metrics []*backend.Metrics) []*backend.PodMetrics { go func() { serverRunner.StartManager() }() - - // Wait the reconcilers to populate the datastore. - time.Sleep(5 * time.Second) - return podMetrics } func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) {
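For reference, a minimal sketch of how the new reconciler folds an EndpointSlice endpoint into the datastore: only ready endpoints (and, when -zone is set, only endpoints in that zone) are kept, and each key is a backend.Pod whose Address is the endpoint IP joined with the InferencePool's TargetPortNumber. This is illustrative only; it uses the unexported updateDatastore and datastore fields, so like the new unit test it would have to live in package backend, and the names and IPs are made up:

package backend

import (
	"sync"

	"inference.networking.x-k8s.io/gateway-api-inference-extension/api/v1alpha1"
	v1 "k8s.io/api/core/v1"
	discoveryv1 "k8s.io/api/discovery/v1"
)

// exampleEndpointToPod is a hypothetical helper showing the endpoint-to-Pod mapping.
func exampleEndpointToPod() []string {
	ready := true
	zone := ""
	slice := &discoveryv1.EndpointSlice{
		Endpoints: []discoveryv1.Endpoint{{
			TargetRef:  &v1.ObjectReference{Name: "pod1"},
			Addresses:  []string{"10.0.0.1"},
			Zone:       &zone,
			Conditions: discoveryv1.EndpointConditions{Ready: &ready},
		}},
	}
	pool := &v1alpha1.InferencePool{
		Spec: v1alpha1.InferencePoolSpec{TargetPortNumber: 8000},
	}

	ds := &K8sDatastore{pods: &sync.Map{}}
	r := &EndpointSliceReconciler{Datastore: ds, Zone: ""}
	r.updateDatastore(slice, pool)

	var addrs []string
	ds.pods.Range(func(k, _ any) bool {
		addrs = append(addrs, k.(Pod).Address) // "10.0.0.1:8000"
		return true
	})
	return addrs
}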