diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index 57dc2469..80c30e19 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -25,6 +25,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -178,6 +179,7 @@ func TestInferenceModelReconciler(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Create a fake client with no InferenceModel objects. scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) _ = v1alpha2.Install(scheme) initObjs := []client.Object{} if test.model != nil { @@ -186,6 +188,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInAPIServer { initObjs = append(initObjs, m) } + fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(initObjs...). @@ -196,7 +199,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInStore { ds.ModelSetIfOlder(m) } - ds.PoolSet(pool) + _ = ds.PoolSet(context.Background(), fakeClient, pool) reconciler := &InferenceModelReconciler{ Client: fakeClient, Record: record.NewFakeRecorder(10), diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go index 0738181f..fb7d7727 100644 --- a/pkg/epp/controller/inferencepool_reconciler.go +++ b/pkg/epp/controller/inferencepool_reconciler.go @@ -18,7 +18,6 @@ package controller import ( "context" - "reflect" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/client-go/tools/record" @@ -60,28 +59,15 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques c.Datastore.Clear() return ctrl.Result{}, nil } - - c.updateDatastore(ctx, infPool) + // update pool in datastore + if err := c.Datastore.PoolSet(ctx, c.Client, infPool); err != nil { + logger.Error(err, "Failed to update datastore") + return ctrl.Result{}, err + } return ctrl.Result{}, nil } -func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *v1alpha2.InferencePool) { - logger := log.FromContext(ctx) - oldPool, err := c.Datastore.PoolGet() - c.Datastore.PoolSet(newPool) - if err != nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { - logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", newPool.Spec.Selector) - // A full resync is required to address two cases: - // 1) At startup, the pod events may get processed before the pool is synced with the datastore, - // and hence they will not be added to the store since pool selector is not known yet - // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need - // to resync the whole pool: remove pods in the store that don't match the new selector and add - // the ones that may have existed already to the store. - c.Datastore.PodResyncAll(ctx, c.Client, newPool) - } -} - func (c *InferencePoolReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). For(&v1alpha2.InferencePool{}). diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 494adeb7..6d1af8d9 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -27,7 +27,6 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" @@ -41,8 +40,7 @@ type PodReconciler struct { func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - pool, err := c.Datastore.PoolGet() - if err != nil { + if !c.Datastore.PoolHasSynced() { logger.V(logutil.TRACE).Info("Skipping reconciling Pod because the InferencePool is not available yet") // When the inferencePool is initialized it lists the appropriate pods and populates the datastore, so no need to requeue. return ctrl.Result{}, nil @@ -60,7 +58,7 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R return ctrl.Result{}, err } - c.updateDatastore(logger, pod, pool) + c.updateDatastore(logger, pod) return ctrl.Result{}, nil } @@ -70,13 +68,13 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(c) } -func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod, pool *v1alpha2.InferencePool) { +func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podutil.IsPodReady(pod) { + if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) { logger.V(logutil.DEBUG).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { - if c.Datastore.PodUpdateOrAddIfNotExist(pod, pool) { + if c.Datastore.PodUpdateOrAddIfNotExist(pod) { logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) } else { logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index e4cb0b62..d2bdd5d0 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -182,9 +182,9 @@ func TestPodReconciler(t *testing.T) { // Configure the initial state of the datastore. store := datastore.NewDatastore(t.Context(), pmf) - store.PoolSet(test.pool) + _ = store.PoolSet(t.Context(), fakeClient, test.pool) for _, pod := range test.existingPods { - store.PodUpdateOrAddIfNotExist(pod, pool) + store.PodUpdateOrAddIfNotExist(pod) } podReconciler := &PodReconciler{Client: fakeClient, Datastore: store} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 5435e3af..f8378d25 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "reflect" "sync" corev1 "k8s.io/api/core/v1" @@ -44,7 +45,10 @@ var ( // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) type Datastore interface { // InferencePool operations - PoolSet(pool *v1alpha2.InferencePool) + // PoolSet sets the given pool in datastore. If the given pool has different label selector than the previous pool + // that was stored, the function triggers a resync of the pods to keep the datastore updated. If the given pool + // is nil, this call triggers the datastore.Clear() function. + PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error PoolGet() (*v1alpha2.InferencePool, error) PoolHasSynced() bool PoolLabelsMatch(podLabels map[string]string) bool @@ -60,10 +64,9 @@ type Datastore interface { // PodGetAll returns all pods and metrics, including fresh and stale. PodGetAll() []backendmetrics.PodMetrics // PodList lists pods matching the given predicate. - PodList(func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics - PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool + PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool PodDelete(namespacedName types.NamespacedName) - PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) // Clears the store state, happens when the pool gets deleted. Clear() @@ -102,10 +105,31 @@ func (ds *datastore) Clear() { } // /// InferencePool APIs /// -func (ds *datastore) PoolSet(pool *v1alpha2.InferencePool) { +func (ds *datastore) PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error { + if pool == nil { + ds.Clear() + return nil + } + logger := log.FromContext(ctx) ds.poolAndModelsMu.Lock() defer ds.poolAndModelsMu.Unlock() + + oldPool := ds.pool ds.pool = pool + if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) { + logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector) + // A full resync is required to address two cases: + // 1) At startup, the pod events may get processed before the pool is synced with the datastore, + // and hence they will not be added to the store since pool selector is not known yet + // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need + // to resync the whole pool: remove pods in the store that don't match the new selector and add + // the ones that may have existed already to the store. + if err := ds.podResyncAll(ctx, client); err != nil { + return fmt.Errorf("failed to update pods according to the pool selector - %w", err) + } + } + + return nil } func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { @@ -229,7 +253,7 @@ func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []b return res } -func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool { +func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { namespacedName := types.NamespacedName{ Name: pod.Name, Namespace: pod.Namespace, @@ -247,27 +271,35 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.In return ok } -func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) { +func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { + v, ok := ds.pods.LoadAndDelete(namespacedName) + if ok { + pmr := v.(backendmetrics.PodMetrics) + pmr.StopRefreshLoop() + } +} + +func (ds *datastore) podResyncAll(ctx context.Context, ctrlClient client.Client) error { logger := log.FromContext(ctx) podList := &corev1.PodList{} if err := ctrlClient.List(ctx, podList, &client.ListOptions{ - LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector), - Namespace: pool.Namespace, + LabelSelector: selectorFromInferencePoolSelector(ds.pool.Spec.Selector), + Namespace: ds.pool.Namespace, }); err != nil { - log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients") - return + return fmt.Errorf("failed to list pods - %w", err) } activePods := make(map[string]bool) for _, pod := range podList.Items { - if podutil.IsPodReady(&pod) { - namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - activePods[pod.Name] = true - if ds.PodUpdateOrAddIfNotExist(&pod, pool) { - logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) - } else { - logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) - } + if !podutil.IsPodReady(&pod) { + continue + } + namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + activePods[pod.Name] = true + if ds.PodUpdateOrAddIfNotExist(&pod) { + logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) + } else { + logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) } } @@ -281,14 +313,8 @@ func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, return true } ds.pods.Range(deleteFn) -} -func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { - v, ok := ds.pods.LoadAndDelete(namespacedName) - if ok { - pmr := v.(backendmetrics.PodMetrics) - pmr.StopRefreshLoop() - } + return nil } func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector { diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index abbff429..e8c77d37 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -27,7 +27,10 @@ import ( "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" @@ -71,9 +74,15 @@ func TestPool(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) datastore := NewDatastore(context.Background(), pmf) - datastore.PoolSet(tt.inferencePool) + _ = datastore.PoolSet(context.Background(), fakeClient, tt.inferencePool) gotPool, gotErr := datastore.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { t.Errorf("Unexpected error diff (+got/-want): %s", diff) @@ -320,11 +329,17 @@ func TestMetrics(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) ds := NewDatastore(ctx, pmf) - ds.PoolSet(inferencePool) + _ = ds.PoolSet(ctx, fakeClient, inferencePool) for _, pod := range test.storePods { - ds.PodUpdateOrAddIfNotExist(pod, inferencePool) + ds.PodUpdateOrAddIfNotExist(pod) } assert.EventuallyWithT(t, func(t *assert.CollectT) { got := ds.PodGetAll() diff --git a/pkg/epp/util/pod/pod.go b/pkg/epp/util/pod/pod.go index 9f564024..4fcb948f 100644 --- a/pkg/epp/util/pod/pod.go +++ b/pkg/epp/util/pod/pod.go @@ -21,6 +21,9 @@ import ( ) func IsPodReady(pod *corev1.Pod) bool { + if !pod.DeletionTimestamp.IsZero() { + return false + } for _, condition := range pod.Status.Conditions { if condition.Type == corev1.PodReady { if condition.Status == corev1.ConditionTrue {