From bf593dd73b1100bb8399749aeda06d8a5853e3ff Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Sun, 4 May 2025 16:45:27 +0300 Subject: [PATCH 01/18] Add P/D scheduler - use 2 schedulers in it, one for prefill and one for decode. P/D scheduler is enabled by environment variable value, list of scorers and their weight are defined by environment variables + delete pd-filter --- pkg/epp/scheduling/config_utils.go | 84 ++++++++++++++++++ pkg/epp/scheduling/local_config.go | 15 ---- pkg/epp/scheduling/pd_config.go | 78 +++++++++++++++++ pkg/epp/scheduling/pd_scheduler.go | 86 +++++++++++++++++++ .../scheduling/plugins/filter/pd_filter.go | 67 ++++++--------- pkg/epp/scheduling/scheduler.go | 6 ++ pkg/epp/server/runserver.go | 9 +- 7 files changed, 288 insertions(+), 57 deletions(-) create mode 100644 pkg/epp/scheduling/config_utils.go create mode 100644 pkg/epp/scheduling/pd_config.go create mode 100644 pkg/epp/scheduling/pd_scheduler.go diff --git a/pkg/epp/scheduling/config_utils.go b/pkg/epp/scheduling/config_utils.go new file mode 100644 index 000000000..4145dbe1b --- /dev/null +++ b/pkg/epp/scheduling/config_utils.go @@ -0,0 +1,84 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scheduling + +import ( + "context" + "fmt" + + "github.com/go-logr/logr" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" +) + +const ( + prefillKvCacheScorerEnablementEnvVar = "PREFILL_ENABLE_KVCACHE_AWARE_SCORER" + prefillLoadAwareScorerEnablementEnvVar = "PREFILL_ENABLE_LOAD_AWARE_SCORER" + decodeKvCacheScorerEnablementEnvVar = "DECODE_ENABLE_KVCACHE_AWARE_SCORER" + decodeLoadAwareScorerEnablementEnvVar = "DECODE_ENABLE_LOAD_AWARE_SCORER" + + prefillKvCacheScorerWeightEnvVar = "PREFILL_KVCACHE_AWARE_SCORER_WEIGHT" + prefillLoadAwareScorerWeightEnvVar = "PREFILL_LOAD_AWARE_SCORER_WEIGHT" + decodeKvCacheScorerWeightEnvVar = "DECODE_KVCACHE_AWARE_SCORER_WEIGHT" + decodeLoadAwareScorerWeightEnvVar = "DECODE_LOAD_AWARE_SCORER_WEIGHT" + + pdEnabledEnvKey = "PD_ENABLED" + + pdPromptLenThresholdEnvKey = "PD_PROMPT_LEN_THRESHOLD" + pdPromptLenThresholdDefault = 10 +) + +const ( + loadAwareScorerName = "LoadAwareScorer" + kvCacheAwareScorerName = "KVCacheAwareScorer" +) + +func addScorerByEnvironment(ctx context.Context, config *SchedulerConfig, scorerName string, scorerEnabledEnvKey string, weightEnvKey string, logger logr.Logger) { + if envutil.GetEnvString(scorerEnabledEnvKey, "false", logger) != "true" { + logger.Info(fmt.Sprintf("Skipping %s creation as it is not enabled", scorerName)) + return + } + + weight := envutil.GetEnvInt(weightEnvKey, 1, logger) + scorer, err := createScorerByName(ctx, scorerName) + if err != nil { + logger.Error(err, "Failed to create scorrer") + return + } + + defaultConfig.scorers[scorer] = weight + logger.Info("Initialized scorer", "scorer", scorerName, "weight", weight) +} + +func createScorerByName(ctx context.Context, name string) (plugins.Scorer, error) { + switch name { + case loadAwareScorerName: + return 
&scorer.LoadAwareScorer{}, nil + case kvCacheAwareScorerName: + return scorer.NewKVCacheAwareScorer(ctx) + } + return nil, fmt.Errorf("invalid scorer type %s", name) +} + +func getPDEnabledFromEnvironment(logger logr.Logger) bool { + return envutil.GetEnvString(pdEnabledEnvKey, "false", logger) == "true" +} + +func getPDPromptLenThresholdFromEnvironment(logger logr.Logger) int { + return envutil.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger) +} diff --git a/pkg/epp/scheduling/local_config.go b/pkg/epp/scheduling/local_config.go index 85b91d7cd..78303d814 100644 --- a/pkg/epp/scheduling/local_config.go +++ b/pkg/epp/scheduling/local_config.go @@ -20,7 +20,6 @@ import ( "context" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" @@ -41,7 +40,6 @@ func setDefaultConfig() { // this configuration is a temporary state, it should be better streamlined. 
setLoadAwareScorer() setKVCacheAwareScorer() - setPDFilter() defaultConfig.picker = picker.NewMaxScorePicker() } @@ -79,16 +77,3 @@ func setKVCacheAwareScorer() { defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight) } - -func setPDFilter() { - ctx := context.Background() - loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) - - if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" { - loggerDebug.Info("Skipping PDFilter creation as it is not enabled") - return - } - - defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter) - loggerDebug.Info("Initialized PDFilter") -} diff --git a/pkg/epp/scheduling/pd_config.go b/pkg/epp/scheduling/pd_config.go new file mode 100644 index 000000000..0fc2283c7 --- /dev/null +++ b/pkg/epp/scheduling/pd_config.go @@ -0,0 +1,78 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scheduling + +import ( + "context" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +var prefillConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{filter.PrefillFilter}, + scorers: map[plugins.Scorer]int{}, + picker: picker.NewMaxScorePicker(), + postSchedulePlugins: []plugins.PostSchedule{}, +} +var decodeConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{filter.DecodeFilter}, + scorers: map[plugins.Scorer]int{}, + picker: picker.NewMaxScorePicker(), + postSchedulePlugins: []plugins.PostSchedule{}, +} + +var IsPDEnabled = false +var PromptLengthThreshold int + +func init() { + ctx := context.Background() + loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG) + + loadPrefillConfiguration(ctx, loggerDebug) + loadDecodeConfiguration(ctx, loggerDebug) + + // set IsPDEnabled by environment + IsPDEnabled = getPDEnabledFromEnvironment(loggerDebug) + PromptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug) +} + +func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) { + // add scorers + addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) + addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger) + + // set filter + // TODO - do we want to keep default filters? 
+ prefillConfig.filters = []plugins.Filter{filter.PrefillFilter} +} + +func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) { + // add scorers + addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) + addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger) + + // set filter + // TODO - do we want to keep default filters? + decodeConfig.filters = []plugins.Filter{filter.DecodeFilter} +} diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go new file mode 100644 index 000000000..a3b3c73b8 --- /dev/null +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -0,0 +1,86 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package scheduling implements request scheduling algorithms. 
+package scheduling + +import ( + "context" + "fmt" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + prefillPodHeader = "x-prefiller-url" +) + +func NewPDScheduler(datastore Datastore) *PDScheduler { + return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig) +} + +func NewPDSchedulerWithConfig(datastore Datastore, prefillConfig *SchedulerConfig, decodeConfig *SchedulerConfig) *PDScheduler { + return &PDScheduler{ + datastore: datastore, + prefillScheduler: NewSchedulerWithConfig(datastore, prefillConfig), + decodeScheduler: NewSchedulerWithConfig(datastore, decodeConfig), + } +} + +type PDScheduler struct { + datastore Datastore + prefillScheduler *Scheduler + decodeScheduler *Scheduler +} + +// Schedule finds the target pod based on metrics and the requested lora adapter. +func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { + logger := log.FromContext(ctx).WithValues("pd-schedule", req) + loggerDebug := logger.V(logutil.DEBUG) + + if len(req.Prompt) < PromptLengthThreshold { + // prompt is short enough - use decode scheduling logic + return s.decodeScheduler.Schedule(ctx, req) + } + + pool, err := s.datastore.PoolGet() + if err != nil { + return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + } + + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. 
+ sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber) + + // prompt requires processing on two pods - prefill and decode + // start with calculating of the prefill pod + res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger) + if err != nil { + return nil, err + } + + if res.TargetPod != nil { + url := fmt.Sprintf("http://%s:%d", res.TargetPod.GetPod().Address, sCtx.TargetPort) + sCtx.MutatedHeaders[prefillPodHeader] = url + } + + // get decode pod + return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger) +} diff --git a/pkg/epp/scheduling/plugins/filter/pd_filter.go b/pkg/epp/scheduling/plugins/filter/pd_filter.go index 228d18143..a70c864d5 100644 --- a/pkg/epp/scheduling/plugins/filter/pd_filter.go +++ b/pkg/epp/scheduling/plugins/filter/pd_filter.go @@ -16,61 +16,46 @@ limitations under the License. package filter import ( - "fmt" - "math/rand/v2" - - "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -const ( - prefillPodHeader = "x-prefiller-url" ) -var PDFilter = &baseFilter{ - name: "p/d filter", - filter: prefillDecodeFilterFunc, +// //////////////////////////// +// Prefill filter +var PrefillFilter = &baseFilter{ + name: "prefill filter", + filter: prefillFilterFunc, } -// prefillDecodeFilterFunc implements a pod selection strategy that filters out pods, -// which role is 'prefill', in addition a header with selected prefill pod is added -// -// Initial implementation: -// 1 - select one random pod marked as 'prefill' and add it name to header -// 2 - return a random pod that marked as "decode" or "both" -// -// Returns: -// - Filtered slice of pod metrics, could contain one or zerro elements -func prefillDecodeFilterFunc(ctx 
*types.SchedulingContext, pods []types.Pod) []types.Pod { - loggerDebug := log.FromContext(ctx).WithName("pd_filter").V(logutil.DEBUG) - - pPods := make([]types.Pod, 0) - dPods := make([]types.Pod, 0) +// prefillFilterFunc filters out all pods that are not marked as "prefill" +func prefillFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filteredPods := make([]types.Pod, 0) for _, pod := range pods { if pod.GetPod().Role == metrics.Prefill { - pPods = append(pPods, pod) - } else if pod.GetPod().Role == metrics.Decode || pod.GetPod().Role == metrics.Both { - dPods = append(dPods, pod) + filteredPods = append(filteredPods, pod) } } - if len(pPods) > 0 { - // select a random prefill pod - randomIndex := rand.IntN(len(pPods)) - url := fmt.Sprintf("http://%s:%d", pPods[randomIndex].GetPod().Address, ctx.TargetPort) - loggerDebug.Info("Prefill pod selected", "url", url) + return filteredPods +} + +// //////////////////////////// +// Decode filter +var DecodeFilter = &baseFilter{ + name: "decode filter", + filter: decodeFilterFunc, +} - ctx.MutatedHeaders[prefillPodHeader] = url - } +// decodeFilterFunc filters out all pods that are not marked as "decode" or "both" +func decodeFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filteredPods := make([]types.Pod, 0) - if len(dPods) > 1 { - // leave only one pod - randomIndex := rand.IntN(len(dPods)) - return []types.Pod{dPods[randomIndex]} + for _, pod := range pods { + if pod.GetPod().Role == metrics.Decode || pod.GetPod().Role == metrics.Both { + filteredPods = append(filteredPods, pod) + } } - return dPods + return filteredPods } diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index f4e1714d4..e1774fdac 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" 
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -112,6 +113,11 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber) + + return s.scheduleWithContext(ctx, sCtx, req, loggerDebug) +} + +func (s *Scheduler) scheduleWithContext(ctx context.Context, sCtx *types.SchedulingContext, req *types.LLMRequest, loggerDebug logr.Logger) (*types.Result, error) { loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) s.runPreSchedulePlugins(sCtx) diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index 0c0a6a6dc..dbe63627d 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -137,7 +137,14 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { } else { srv = grpc.NewServer() } - extProcServer := handlers.NewStreamingServer(scheduling.NewScheduler(r.Datastore), r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) + + var scheduler handlers.Scheduler + if scheduling.IsPDEnabled { + scheduler = scheduling.NewPDScheduler(r.Datastore) + } else { + scheduler = scheduling.NewScheduler(r.Datastore) + } + extProcServer := handlers.NewStreamingServer(scheduler, r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) extProcPb.RegisterExternalProcessorServer( srv, extProcServer, From a694eb7b11fc9a087e63caf332174fb1dba34f33 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Sun, 4 May 2025 16:53:23 +0300 Subject: [PATCH 02/18] Remove unused variable --- pkg/epp/scheduling/pd_scheduler.go | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go index a3b3c73b8..56717beac 100644 --- a/pkg/epp/scheduling/pd_scheduler.go +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -24,7 +24,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) const ( @@ -52,7 +51,6 @@ type PDScheduler struct { // Schedule finds the target pod based on metrics and the requested lora adapter. func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("pd-schedule", req) - loggerDebug := logger.V(logutil.DEBUG) if len(req.Prompt) < PromptLengthThreshold { // prompt is short enough - use decode scheduling logic From 3ac82b2fed627755d148c6009637d00a52b40266 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Sun, 4 May 2025 16:54:02 +0300 Subject: [PATCH 03/18] Update readme file with envirnment variables relevant to P/D scheduler --- README.md | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dd262dcfc..efdd5ceef 100644 --- a/README.md +++ b/README.md @@ -23,9 +23,42 @@ export ENABLE_LOAD_AWARE_SCORER=true export LOAD_AWARE_SCORER_WEIGHT=1.0 ``` -To enable PDFilter, the following env var must be configured: +To enable PD Scheduler, the following env var must be configured: ``` -export ENABLE_PD_FILTER=true +export PD_ENABLED=true +``` + +To define prompt length threshold (requests with length is longer than the value defined here will be processed using prefill-decode process), the following env var must be configured: +``` +export PD_PROMPT_LEN_THRESHOLD=10 +``` + +Prefill scheduler configuration: + +To enable and configure kv cache scorer, the following env vars must be configured: +``` 
+export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true +export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable and configure load aware scorer, the following env vars must be configured: +``` +export PREFILL_ENABLE_LOAD_AWARE_SCORER=true +export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0 +``` + +Decode scheduler configuration: + +To enable and configure kv cache scorer, the following env vars must be configured: +``` +export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true +export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0 +``` + +To enable and configure load aware scorer, the following env vars must be configured: +``` +export DECODE_ENABLE_LOAD_AWARE_SCORER=true +export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0 ``` --- [Inference Gateways]:#concepts-and-definitions From f43d9c7b236d18864a599214e75e207f73cee9ee Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Sun, 4 May 2025 17:45:04 +0300 Subject: [PATCH 04/18] Fix problem caused by merge --- pkg/epp/scheduling/pd_scheduler.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go index 56717beac..89770a292 100644 --- a/pkg/epp/scheduling/pd_scheduler.go +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -82,3 +82,7 @@ func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*typ // get decode pod return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger) } + +func (s *PDScheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + return s.decodeScheduler.RunPostResponsePlugins(ctx, req, targetPodName) +} From 9948464c6754e68d94d55861151314f8befd8f06 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 09:10:57 +0300 Subject: [PATCH 05/18] Add documentation for PDScheduler.Schedule function --- pkg/epp/scheduling/pd_scheduler.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go index 
89770a292..b522d3645 100644 --- a/pkg/epp/scheduling/pd_scheduler.go +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -49,6 +49,12 @@ type PDScheduler struct { } // Schedule finds the target pod based on metrics and the requested lora adapter. +// PD schedule uses two base schedules to process request, configuration is currently loaded from environment variables. +// If request prompt is short enough (defined by threshold in the configuration) - use default behavior +// If request prompt is long enough to use prefill-decode process, +// 1 - find the pod for prefill, save it url in a special header, for this use Scheduler configured for this goal, which uses prefill filter +// and scorers according to configuration. +// 2 - find the pod for decode, use Scheduler configured for this goal, which uses decode filer and scorers defined in the configuration func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("pd-schedule", req) From b649586e85b015c57c667021c21b1218ca4757d8 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 09:13:08 +0300 Subject: [PATCH 06/18] Update names of prefill and decode filters to avoid spaces --- pkg/epp/scheduling/plugins/filter/pd_filter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/epp/scheduling/plugins/filter/pd_filter.go b/pkg/epp/scheduling/plugins/filter/pd_filter.go index a70c864d5..cc4bdfc7c 100644 --- a/pkg/epp/scheduling/plugins/filter/pd_filter.go +++ b/pkg/epp/scheduling/plugins/filter/pd_filter.go @@ -23,7 +23,7 @@ import ( // //////////////////////////// // Prefill filter var PrefillFilter = &baseFilter{ - name: "prefill filter", + name: "prefill_filter", filter: prefillFilterFunc, } @@ -43,7 +43,7 @@ func prefillFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.P // //////////////////////////// // Decode filter var DecodeFilter = &baseFilter{ - name: "decode filter", + name: 
"decode_filter", filter: decodeFilterFunc, } From 16e4496f3b97c9dfd3f34fa2b8b4112ef1d6c3e8 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 09:16:25 +0300 Subject: [PATCH 07/18] Update comment for prefill/decode fitlers --- pkg/epp/scheduling/plugins/filter/pd_filter.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pkg/epp/scheduling/plugins/filter/pd_filter.go b/pkg/epp/scheduling/plugins/filter/pd_filter.go index cc4bdfc7c..fd4c5a8cc 100644 --- a/pkg/epp/scheduling/plugins/filter/pd_filter.go +++ b/pkg/epp/scheduling/plugins/filter/pd_filter.go @@ -20,8 +20,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) -// //////////////////////////// -// Prefill filter +// PrefillFilter - filters out all pods that are not marked as decode/both pod role var PrefillFilter = &baseFilter{ name: "prefill_filter", filter: prefillFilterFunc, @@ -40,8 +39,7 @@ func prefillFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.P return filteredPods } -// //////////////////////////// -// Decode filter +// DecodeFilter - fiters out all pods that are not marked as prefill pod role var DecodeFilter = &baseFilter{ name: "decode_filter", filter: decodeFilterFunc, From d58cb144f033c192b3049b44f1423250bf6bedf1 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 09:36:18 +0300 Subject: [PATCH 08/18] Change IsPDEnabled to PDEnabled --- pkg/epp/scheduling/pd_config.go | 4 ++-- pkg/epp/server/runserver.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/epp/scheduling/pd_config.go b/pkg/epp/scheduling/pd_config.go index 0fc2283c7..46c4cb7e2 100644 --- a/pkg/epp/scheduling/pd_config.go +++ b/pkg/epp/scheduling/pd_config.go @@ -42,7 +42,7 @@ var decodeConfig = &SchedulerConfig{ postSchedulePlugins: []plugins.PostSchedule{}, } -var IsPDEnabled = false +var PDEnabled = false var PromptLengthThreshold int func init() { @@ -53,7 +53,7 @@ func init() { 
loadDecodeConfiguration(ctx, loggerDebug) // set IsPDEnabled by environment - IsPDEnabled = getPDEnabledFromEnvironment(loggerDebug) + PDEnabled = getPDEnabledFromEnvironment(loggerDebug) PromptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug) } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index dbe63627d..9b8ea4177 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -139,7 +139,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { } var scheduler handlers.Scheduler - if scheduling.IsPDEnabled { + if scheduling.PDEnabled { scheduler = scheduling.NewPDScheduler(r.Datastore) } else { scheduler = scheduling.NewScheduler(r.Datastore) From 542984ecfe18eae00c14f1882e15ff339a1f0b1b Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 09:46:41 +0300 Subject: [PATCH 09/18] Fix typo in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index efdd5ceef..642849c13 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ To enable PD Scheduler, the following env var must be configured: export PD_ENABLED=true ``` -To define prompt length threshold (requests with length is longer than the value defined here will be processed using prefill-decode process), the following env var must be configured: +To define prompt length threshold (requests with length longer than the value defined here will be processed using prefill-decode process), the following env var must be configured: ``` export PD_PROMPT_LEN_THRESHOLD=10 ``` From 2ef8e19e1245bd45b057b5ac62e68453f331441b Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 10:51:52 +0300 Subject: [PATCH 10/18] Fix pd scheduler behavior for short prompts --- pkg/epp/scheduling/pd_scheduler.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go index 
b522d3645..48cc421b4 100644 --- a/pkg/epp/scheduling/pd_scheduler.go +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -31,14 +31,15 @@ const ( ) func NewPDScheduler(datastore Datastore) *PDScheduler { - return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig) + return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig, defaultConfig) } -func NewPDSchedulerWithConfig(datastore Datastore, prefillConfig *SchedulerConfig, decodeConfig *SchedulerConfig) *PDScheduler { +func NewPDSchedulerWithConfig(datastore Datastore, pConfig *SchedulerConfig, dConfig *SchedulerConfig, defConfig *SchedulerConfig) *PDScheduler { return &PDScheduler{ datastore: datastore, - prefillScheduler: NewSchedulerWithConfig(datastore, prefillConfig), - decodeScheduler: NewSchedulerWithConfig(datastore, decodeConfig), + prefillScheduler: NewSchedulerWithConfig(datastore, pConfig), + decodeScheduler: NewSchedulerWithConfig(datastore, dConfig), + defaultScheduler: NewSchedulerWithConfig(datastore, defConfig), } } @@ -46,6 +47,7 @@ type PDScheduler struct { datastore Datastore prefillScheduler *Scheduler decodeScheduler *Scheduler + defaultScheduler *Scheduler } // Schedule finds the target pod based on metrics and the requested lora adapter. 
@@ -60,7 +62,7 @@ func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*typ if len(req.Prompt) < PromptLengthThreshold { // prompt is short enough - use decode scheduling logic - return s.decodeScheduler.Schedule(ctx, req) + return s.defaultScheduler.Schedule(ctx, req) } pool, err := s.datastore.PoolGet() From c8e2aa6e65fc601f7f3fae1527f900ba2d64416b Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 11:27:47 +0300 Subject: [PATCH 11/18] Fix prefill/decode related text in readme --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 642849c13..3daad8eee 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ This project offers tools for AI Inference, enabling developers to build [Infere --- ## Temporary Fork Configuration -To enable KVCacheAwareScorer, the following env vars must be configured: +To enable KVCacheAwareScorer, the following environment variables must be configured: ``` export ENABLE_KVCACHE_AWARE_SCORER=true export KVCACHE_AWARE_SCORER_WEIGHT=1.0 @@ -17,45 +17,45 @@ export KVCACHE_INDEXER_REDIS_ADDR= export HF_TOKEN= ``` -To enable LoadAwareScorer, the following env vars must be configured: +To enable LoadAwareScorer, the following environment variables must be configured: ``` export ENABLE_LOAD_AWARE_SCORER=true export LOAD_AWARE_SCORER_WEIGHT=1.0 ``` -To enable PD Scheduler, the following env var must be configured: +To enable Prefill/Decode (PD) processing, the following environment variable must be configured: ``` export PD_ENABLED=true ``` -To define prompt length threshold (requests with length longer than the value defined here will be processed using prefill-decode process), the following env var must be configured: +To define the prompt length threshold (requests with a prompt longer than the value defined here will be processed using the prefill-decode process), the following environment variable must be configured: ``` export 
PD_PROMPT_LEN_THRESHOLD=10 ``` -Prefill scheduler configuration: +Prefill configuration: -To enable and configure kv cache scorer, the following env vars must be configured: +To enable and configure the kv cache scorer for prefill, the following environment variables must be configured: ``` export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0 ``` -To enable and configure load aware scorer, the following env vars must be configured: +To enable and configure the load aware scorer for prefill, the following environment variables must be configured: ``` export PREFILL_ENABLE_LOAD_AWARE_SCORER=true export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0 ``` -Decode scheduler configuration: +Decode configuration: -To enable and configure kv cache scorer, the following env vars must be configured: +To enable and configure the kv cache scorer for decode, the following environment variables must be configured: ``` export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0 ``` -To enable and configure load aware scorer, the following env vars must be configured: +To enable and configure the load aware scorer for decode, the following environment variables must be configured: ``` export DECODE_ENABLE_LOAD_AWARE_SCORER=true export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0 From 5fb506e7f74b9db3a8db4970de0b2081180aa94e Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:02:13 +0300 Subject: [PATCH 12/18] Remove redundant filter creation of prefil/decode filters + make promptLengthThreshold local Add function for schedulerContext creation --- pkg/epp/scheduling/pd_config.go | 12 ++---------- pkg/epp/scheduling/pd_scheduler.go | 26 ++++++++++---------------- pkg/epp/scheduling/scheduler.go | 21 ++++++++++++++------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/pkg/epp/scheduling/pd_config.go b/pkg/epp/scheduling/pd_config.go index 46c4cb7e2..ccb4748d8 100644 --- 
a/pkg/epp/scheduling/pd_config.go +++ b/pkg/epp/scheduling/pd_config.go @@ -43,7 +43,7 @@ var decodeConfig = &SchedulerConfig{ } var PDEnabled = false -var PromptLengthThreshold int +var promptLengthThreshold int func init() { ctx := context.Background() @@ -54,25 +54,17 @@ func init() { // set IsPDEnabled by environment PDEnabled = getPDEnabledFromEnvironment(loggerDebug) - PromptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug) + promptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug) } func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) { // add scorers addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger) - - // set filter - // TODO - do we want to keep default filters? - prefillConfig.filters = []plugins.Filter{filter.PrefillFilter} } func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) { // add scorers addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger) addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger) - - // set filter - // TODO - do we want to keep default filters? 
- decodeConfig.filters = []plugins.Filter{filter.DecodeFilter} } diff --git a/pkg/epp/scheduling/pd_scheduler.go b/pkg/epp/scheduling/pd_scheduler.go index 48cc421b4..37822201a 100644 --- a/pkg/epp/scheduling/pd_scheduler.go +++ b/pkg/epp/scheduling/pd_scheduler.go @@ -23,7 +23,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" ) const ( @@ -51,30 +50,25 @@ type PDScheduler struct { } // Schedule finds the target pod based on metrics and the requested lora adapter. -// PD schedule uses two base schedules to process request, configuration is currently loaded from environment variables. -// If request prompt is short enough (defined by threshold in the configuration) - use default behavior -// If request prompt is long enough to use prefill-decode process, -// 1 - find the pod for prefill, save it url in a special header, for this use Scheduler configured for this goal, which uses prefill filter -// and scorers according to configuration. -// 2 - find the pod for decode, use Scheduler configured for this goal, which uses decode filer and scorers defined in the configuration +// PD scheduler uses three base schedulers to process requests, the overall configuration is currently loaded from environment variables. +// If the request prompt is short enough (defined by the threshold in the configuration) - use the default behavior +// If the request prompt is long enough to use prefill-decode process: +// 1 - find the pod for prefill, save its url in a special header. For this, use the Scheduler configured for this goal, which uses the prefill filter +// and scorers according to the configuration. 
+// 2 - find the pod for decode, use the Scheduler configured for this goal, which uses the decode filer and scorers defined in the configuration func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { logger := log.FromContext(ctx).WithValues("pd-schedule", req) - if len(req.Prompt) < PromptLengthThreshold { - // prompt is short enough - use decode scheduling logic + if len(req.Prompt) < promptLengthThreshold { + // the prompt is short enough - use the default scheduling logic return s.defaultScheduler.Schedule(ctx, req) } - pool, err := s.datastore.PoolGet() + sCtx, err := createSchedulerContext(ctx, req, s.datastore) if err != nil { - return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + return nil, err } - // Snapshot pod metrics from the datastore to: - // 1. Reduce concurrent access to the datastore. - // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber) - // prompt requires processing on two pods - prefill and decode // start with calculating of the prefill pod res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger) diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index f7ab31f5a..bb48fa485 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -100,12 +100,8 @@ type Datastore interface { PodGetAll() []backendmetrics.PodMetrics } -// Schedule finds the target pod based on metrics and the requested lora adapter. 
-func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { - logger := log.FromContext(ctx).WithValues("request", req) - loggerDebug := logger.V(logutil.DEBUG) - - pool, err := s.datastore.PoolGet() +func createSchedulerContext(ctx context.Context, req *types.LLMRequest, datastore Datastore) (*types.SchedulingContext, error) { + pool, err := datastore.PoolGet() if err != nil { return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods } @@ -113,7 +109,18 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types // Snapshot pod metrics from the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber) + return types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber), nil +} + +// Schedule finds the target pod based on metrics and the requested lora adapter. 
+func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) { + logger := log.FromContext(ctx).WithValues("request", req) + loggerDebug := logger.V(logutil.DEBUG) + + sCtx, err := createSchedulerContext(ctx, req, s.datastore) + if err != nil { + return nil, err + } return s.scheduleWithContext(ctx, sCtx, req, loggerDebug) } From a993bc71285533923b83f9600db3d564ab332a3c Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:03:37 +0300 Subject: [PATCH 13/18] Fixes in readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3daad8eee..dc8921795 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ This project offers tools for AI Inference, enabling developers to build [Infere --- ## Temporary Fork Configuration -To enable KVCacheAwareScorer, the following environment variables must be configured: +To enable the KVCacheAwareScorer, the following environment variables must be configured: ``` export ENABLE_KVCACHE_AWARE_SCORER=true export KVCACHE_AWARE_SCORER_WEIGHT=1.0 @@ -17,7 +17,7 @@ export KVCACHE_INDEXER_REDIS_ADDR= export HF_TOKEN= ``` -To enable LoadAwareScorer, the following environment variables must be configured: +To enable the LoadAwareScorer, the following environment variables must be configured: ``` export ENABLE_LOAD_AWARE_SCORER=true export LOAD_AWARE_SCORER_WEIGHT=1.0 From 205f34acf2574a95a5b670524cf16e45274680cb Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:05:39 +0300 Subject: [PATCH 14/18] fix compilation prblem --- pkg/epp/scheduling/scheduler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index bb48fa485..b56d20ca7 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -109,7 +109,7 @@ func createSchedulerContext(ctx context.Context, req *types.LLMRequest, datastor // Snapshot pod metrics from 
the datastore to: // 1. Reduce concurrent access to the datastore. // 2. Ensure consistent data during the scheduling operation of a request. - return types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber), nil + return types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(datastore.PodGetAll()), pool.Spec.TargetPortNumber), nil } // Schedule finds the target pod based on metrics and the requested lora adapter. From c443a6bd9d2a37b887487ea3ea60c9120bbb16eb Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:05:56 +0300 Subject: [PATCH 15/18] add pd scheduler test --- pkg/epp/scheduling/pd_scheduler_test.go | 155 ++++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 pkg/epp/scheduling/pd_scheduler_test.go diff --git a/pkg/epp/scheduling/pd_scheduler_test.go b/pkg/epp/scheduling/pd_scheduler_test.go new file mode 100644 index 000000000..482e803bb --- /dev/null +++ b/pkg/epp/scheduling/pd_scheduler_test.go @@ -0,0 +1,155 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package scheduling + +import ( + "context" + "fmt" + "testing" + + k8stypes "k8s.io/apimachinery/pkg/types" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +// Tests the default scheduler configuration and expected behavior. +func TestPDSchedule(t *testing.T) { + // Set configuration + PDEnabled = true + promptLengthThreshold = 10 + prefillConfig.filters = []plugins.Filter{filter.PrefillFilter} + prefillConfig.scorers = map[plugins.Scorer]int{} + decodeConfig.filters = []plugins.Filter{filter.DecodeFilter} + decodeConfig.scorers = map[plugins.Scorer]int{} + + pod1 := &backendmetrics.FakePodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + Address: "1.2.3.4", + Role: backendmetrics.Prefill, + }, + Metrics: &backendmetrics.Metrics{}, + } + pod2 := &backendmetrics.FakePodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + Address: "5.6.7.8", + Role: backendmetrics.Decode, + }, + Metrics: &backendmetrics.Metrics{}, + } + + tests := []struct { + name string + req *types.LLMRequest + input []*backendmetrics.FakePodMetrics + wantRes *types.Result + err bool + }{ + { + name: "no pods in datastore", + req: &types.LLMRequest{ + Model: "any-model", + ResolvedTargetModel: "any-model", + Critical: true, + Prompt: "12345678901", + }, + input: []*backendmetrics.FakePodMetrics{}, + err: true, + }, + { + name: "one pod, short prompt", + req: &types.LLMRequest{ + Model: "critical", + ResolvedTargetModel: "critical", + Critical: true, + Prompt: "123", + }, + // pod2 will be picked because it is decode pod + input: []*backendmetrics.FakePodMetrics{pod1}, + wantRes: &types.Result{ + 
TargetPod: &types.ScoredPod{ + Pod: pod1, + }, + MutatedHeaders: map[string]string{}, + }, + }, + { + name: "1P1D", + req: &types.LLMRequest{ + Model: "critical", + ResolvedTargetModel: "critical", + Critical: true, + Prompt: "12345678901", + }, + // pod2 will be picked because it is decode pod + input: []*backendmetrics.FakePodMetrics{pod1, pod2}, + wantRes: &types.Result{ + TargetPod: &types.ScoredPod{ + Pod: pod2, + }, + MutatedHeaders: map[string]string{"x-prefiller-url": "http://1.2.3.4:0"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scheduler := NewPDScheduler(&fakeDataStore{pods: test.input}) + got, err := scheduler.Schedule(context.Background(), test.req) + + fmt.Printf("Test %s:\n", test.name) + fmt.Printf("Result: %#v\n", got) + fmt.Printf("Expected: %#v\n", test.wantRes) + + if test.err != (err != nil) { + t.Errorf("Unexpected error, got %v, want %v", err, test.err) + } + + if test.wantRes != nil && got != nil { + if !mapsEqual(test.wantRes.MutatedHeaders, got.MutatedHeaders) { + fmt.Printf("Mutated headers are not the same\n") + t.Errorf("Mutated headers are not the same\n") + } + if got.TargetPod.GetPod().String() != test.wantRes.TargetPod.GetPod().String() { + fmt.Printf("target pod is not the same\n") + fmt.Printf("wanted: %s\n", test.wantRes.TargetPod.String()) + fmt.Printf("got: %s\n", got.TargetPod.String()) + t.Errorf("Tager pod is not the same") + } + } + + // if diff := cmp.Diff(test.wantRes, got); diff != "" { + // t.Errorf("Unexpected output (-want +got): %v", diff) + // } + }) + } +} + +func mapsEqual(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if bv, ok := b[k]; !ok || bv != v { + return false + } + } + return true +} From e8d9f6491a3a93e1fd1be594a310199c659e3687 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:12:47 +0300 Subject: [PATCH 16/18] add postResponse plugins array to prefile and decode config --- 
pkg/epp/scheduling/pd_config.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/epp/scheduling/pd_config.go b/pkg/epp/scheduling/pd_config.go index ccb4748d8..107ef88e6 100644 --- a/pkg/epp/scheduling/pd_config.go +++ b/pkg/epp/scheduling/pd_config.go @@ -33,6 +33,7 @@ var prefillConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: picker.NewMaxScorePicker(), postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } var decodeConfig = &SchedulerConfig{ preSchedulePlugins: []plugins.PreSchedule{}, @@ -40,6 +41,7 @@ var decodeConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: picker.NewMaxScorePicker(), postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } var PDEnabled = false From 2d0de9ee98e5622760238cabe9ea3feafc2e3b47 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: Mon, 5 May 2025 12:17:01 +0300 Subject: [PATCH 17/18] fix comment in test --- pkg/epp/scheduling/pd_scheduler_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/epp/scheduling/pd_scheduler_test.go b/pkg/epp/scheduling/pd_scheduler_test.go index 482e803bb..d1946ce42 100644 --- a/pkg/epp/scheduling/pd_scheduler_test.go +++ b/pkg/epp/scheduling/pd_scheduler_test.go @@ -81,7 +81,7 @@ func TestPDSchedule(t *testing.T) { Critical: true, Prompt: "123", }, - // pod2 will be picked because it is decode pod + // pod1 will be picked because it is the only one pod input: []*backendmetrics.FakePodMetrics{pod1}, wantRes: &types.Result{ TargetPod: &types.ScoredPod{ @@ -98,7 +98,7 @@ func TestPDSchedule(t *testing.T) { Critical: true, Prompt: "12345678901", }, - // pod2 will be picked because it is decode pod + // pod2 will be picked because it is the decode pod input: []*backendmetrics.FakePodMetrics{pod1, pod2}, wantRes: &types.Result{ TargetPod: &types.ScoredPod{ From 6c4160005a79dce4b0c7a7b5100f2d9d18061535 Mon Sep 17 00:00:00 2001 From: Maya Barnea Date: 
Mon, 5 May 2025 12:30:12 +0300 Subject: [PATCH 18/18] fix pd-scheduler test --- pkg/epp/scheduling/pd_scheduler_test.go | 57 ++++++++++++------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/pkg/epp/scheduling/pd_scheduler_test.go b/pkg/epp/scheduling/pd_scheduler_test.go index d1946ce42..1cec19433 100644 --- a/pkg/epp/scheduling/pd_scheduler_test.go +++ b/pkg/epp/scheduling/pd_scheduler_test.go @@ -21,6 +21,7 @@ import ( "fmt" "testing" + "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" @@ -54,6 +55,28 @@ func TestPDSchedule(t *testing.T) { }, Metrics: &backendmetrics.Metrics{}, } + wantPod1 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + Address: "1.2.3.4", + Role: backendmetrics.Prefill, + }, + Metrics: &backendmetrics.Metrics{ + ActiveModels: map[string]int{}, + WaitingModels: map[string]int{}, + }, + } + wantPod2 := &types.PodMetrics{ + Pod: &backendmetrics.Pod{ + NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + Address: "5.6.7.8", + Role: backendmetrics.Decode, + }, + Metrics: &backendmetrics.Metrics{ + ActiveModels: map[string]int{}, + WaitingModels: map[string]int{}, + }, + } tests := []struct { name string @@ -85,7 +108,7 @@ func TestPDSchedule(t *testing.T) { input: []*backendmetrics.FakePodMetrics{pod1}, wantRes: &types.Result{ TargetPod: &types.ScoredPod{ - Pod: pod1, + Pod: wantPod1, }, MutatedHeaders: map[string]string{}, }, @@ -102,7 +125,8 @@ func TestPDSchedule(t *testing.T) { input: []*backendmetrics.FakePodMetrics{pod1, pod2}, wantRes: &types.Result{ TargetPod: &types.ScoredPod{ - Pod: pod2, + Pod: wantPod2, + Score: 0.0, }, MutatedHeaders: map[string]string{"x-prefiller-url": "http://1.2.3.4:0"}, }, @@ -122,34 +146,9 @@ func 
TestPDSchedule(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if test.wantRes != nil && got != nil { - if !mapsEqual(test.wantRes.MutatedHeaders, got.MutatedHeaders) { - fmt.Printf("Mutated headers are not the same\n") - t.Errorf("Mutated headers are not the same\n") - } - if got.TargetPod.GetPod().String() != test.wantRes.TargetPod.GetPod().String() { - fmt.Printf("target pod is not the same\n") - fmt.Printf("wanted: %s\n", test.wantRes.TargetPod.String()) - fmt.Printf("got: %s\n", got.TargetPod.String()) - t.Errorf("Tager pod is not the same") - } + if diff := cmp.Diff(test.wantRes, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) } - - // if diff := cmp.Diff(test.wantRes, got); diff != "" { - // t.Errorf("Unexpected output (-want +got): %v", diff) - // } }) } } - -func mapsEqual(a, b map[string]string) bool { - if len(a) != len(b) { - return false - } - for k, v := range a { - if bv, ok := b[k]; !ok || bv != v { - return false - } - } - return true -}