diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 9fd401d4e..e674f1c20 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -34,6 +34,7 @@ import ( "k8s.io/client-go/rest" "k8s.io/component-base/metrics/legacyregistry" ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" @@ -43,7 +44,13 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -107,8 +114,22 @@ var ( "Prometheus metric for the LoRA info metrics (must be in vLLM label format).") setupLog = ctrl.Log.WithName("setup") + + // Environment variables + schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog) + prefixCacheScheduling = envutil.GetEnvString("ENABLE_PREFIX_CACHE_SCHEDULING", "false", setupLog) ) +func loadPrefixCacheConfig() prefix.Config { + baseLogger := log.Log.WithName("env-config") + + return prefix.Config{ + HashBlockSize: envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger), + MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger), + LRUIndexerCapacity: envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger), + } +} + func main() { if err := run(); err != nil { os.Exit(1) @@ -172,6 +193,27 @@ func run() error { datastore := datastore.NewDatastore(ctx, pmf) scheduler := scheduling.NewScheduler(datastore) + if schedulerV2 == "true" { + queueScorerWeight := envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog) + kvCacheScorerWeight := envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog) + scorers := map[plugins.Scorer]int{ + &scorer.QueueScorer{}: queueScorerWeight, + &scorer.KVCacheScorer{}: kvCacheScorerWeight, + } + schedConfigOpts := []scheduling.ConfigOption{} + if prefixCacheScheduling == "true" { + prefixScorerWeight := envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", prefix.DefaultScorerWeight, setupLog) + schedConfigOpts = append(schedConfigOpts, scheduling.AddPrefixPlugin(loadPrefixCacheConfig(), prefixScorerWeight)) + } + schedulerConfig := scheduling.NewSchedulerConfig( + []plugins.PreSchedule{}, + []plugins.Filter{filter.NewSheddableCapacityFilter()}, + scorers, + picker.NewMaxScorePicker(), + []plugins.PostSchedule{}, + schedConfigOpts...) 
+ scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig) + } serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 6cc0cdb83..84f0f1f9a 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -209,6 +209,40 @@ var ( []string{"plugin_type", "plugin_name"}, ) + // Prefix indexer Metrics + PrefixCacheSize = compbasemetrics.NewGaugeVec( + &compbasemetrics.GaugeOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_size", + Help: "Size of the prefix indexer.", + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) + + PrefixCacheHitRatio = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_hit_ratio", + Help: "Ratio of prefix length matched to total prefix length in the cache lookup.", + // Buckets from 0.0 to 1.0 in increments + Buckets: []float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) + + PrefixCacheHitLength = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: InferenceExtension, + Name: "prefix_indexer_hit_bytes", + Help: "Length of the prefix match in number of bytes in the cache lookup.", + Buckets: []float64{0, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{}, + ) + // Info Metrics InferenceExtensionInfo = compbasemetrics.NewGaugeVec( &compbasemetrics.GaugeOpts{ @@ -244,6 +278,10 @@ func Register() { legacyregistry.MustRegister(SchedulerE2ELatency) legacyregistry.MustRegister(InferenceExtensionInfo) + + legacyregistry.MustRegister(PrefixCacheSize) + legacyregistry.MustRegister(PrefixCacheHitRatio) + legacyregistry.MustRegister(PrefixCacheHitLength) }) } @@ -352,6 +390,24 @@ func RecordSchedulerE2ELatency(duration time.Duration) { SchedulerE2ELatency.WithLabelValues().Observe(duration.Seconds()) } +// RecordPrefixCacheSize records the size of the prefix indexer in megabytes. +func RecordPrefixCacheSize(size int64) { + PrefixCacheSize.WithLabelValues().Set(float64(size)) +} + +// RecordPrefixCacheMatch records both the hit ratio and hit length for a prefix indexer match. +// matchedLength is the number of characters that matched, and totalLength is the total prefix length. 
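+// For example (illustrative values): matchedLength=128 and totalLength=512 record hit_bytes=128
+// and hit_ratio=0.25; when totalLength is 0, only the hit length is recorded.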
+func RecordPrefixCacheMatch(matchedLength, totalLength int) { + // Record the hit length metric + PrefixCacheHitLength.WithLabelValues().Observe(float64(matchedLength)) + + // Record the hit ratio metric if totalLength is positive + if totalLength > 0 { + ratio := float64(matchedLength) / float64(totalLength) + PrefixCacheHitRatio.WithLabelValues().Observe(ratio) + } +} + func RecordInferenceExtensionInfo() { if CommitSHA != "" { InferenceExtensionInfo.WithLabelValues(CommitSHA).Set(1) diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index 3a8136a08..4ad6f96e1 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -664,3 +664,106 @@ func TestSchedulerE2ELatency(t *testing.T) { }) } } + +func TestPrefixCacheMetrics(t *testing.T) { + const ( + PrefixCacheSizeMetric = InferenceExtension + "_prefix_indexer_size" + PrefixCacheHitRatioMetric = InferenceExtension + "_prefix_indexer_hit_ratio" + PrefixCacheHitLengthMetric = InferenceExtension + "_prefix_indexer_hit_bytes" + ) + + type cacheMatchRecord struct { + matchedLength int + totalLength int + } + + scenario := struct { + name string + cacheSizes []int64 + cacheMatches []cacheMatchRecord + }{ + name: "multiple cache metrics", + cacheSizes: []int64{1024, 2048, 4096}, + cacheMatches: []cacheMatchRecord{ + { + matchedLength: 5, + totalLength: 10, + }, + { + matchedLength: 0, + totalLength: 10, + }, + { + matchedLength: 10, + totalLength: 10, + }, + { + matchedLength: 7, + totalLength: 10, + }, + { + matchedLength: 64, + totalLength: 128, + }, + { + matchedLength: 0, + totalLength: 128, + }, + }, + } + + Register() + t.Run(scenario.name, func(t *testing.T) { + // Record cache size metrics + for _, size := range scenario.cacheSizes { + RecordPrefixCacheSize(size) + } + + // Record cache match metrics (both hit ratio and hit length) + for _, match := range scenario.cacheMatches { + RecordPrefixCacheMatch(match.matchedLength, match.totalLength) + } + + // Verify cache size metrics + wantCacheSizeMetrics, err := os.Open("testdata/prefix_indexer_size_metric") + defer func() { + if err := wantCacheSizeMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantCacheSizeMetrics, PrefixCacheSizeMetric); err != nil { + t.Error(err) + } + + // Verify hit ratio metrics + wantHitRatioMetrics, err := os.Open("testdata/prefix_indexer_hit_ratio_metric") + defer func() { + if err := wantHitRatioMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitRatioMetrics, PrefixCacheHitRatioMetric); err != nil { + t.Error(err) + } + + // Verify hit length metrics + wantHitLengthMetrics, err := os.Open("testdata/prefix_indexer_hit_bytes_metric") + defer func() { + if err := wantHitLengthMetrics.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantHitLengthMetrics, PrefixCacheHitLengthMetric); err != nil { + t.Error(err) + } + }) +} diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric new file mode 100644 index 000000000..86b48724e --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_bytes_metric @@ -0,0 +1,19 @@ +# HELP inference_extension_prefix_indexer_hit_bytes [ALPHA] Length of the prefix match in 
number of bytes in the cache lookup. +# TYPE inference_extension_prefix_indexer_hit_bytes histogram +inference_extension_prefix_indexer_hit_bytes_bucket{le="0"} 2 +inference_extension_prefix_indexer_hit_bytes_bucket{le="16"} 5 +inference_extension_prefix_indexer_hit_bytes_bucket{le="32"} 5 +inference_extension_prefix_indexer_hit_bytes_bucket{le="64"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="128"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="256"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="512"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="1024"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="2048"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="4096"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="8192"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="16384"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="32768"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="65536"} 6 +inference_extension_prefix_indexer_hit_bytes_bucket{le="+Inf"} 6 +inference_extension_prefix_indexer_hit_bytes_sum 86 +inference_extension_prefix_indexer_hit_bytes_count 6 diff --git a/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric new file mode 100644 index 000000000..e94827cb6 --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_hit_ratio_metric @@ -0,0 +1,16 @@ +# HELP inference_extension_prefix_indexer_hit_ratio [ALPHA] Ratio of prefix length matched to total prefix length in the cache lookup. +# TYPE inference_extension_prefix_indexer_hit_ratio histogram +inference_extension_prefix_indexer_hit_ratio_bucket{le="0"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.1"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.2"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.3"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.4"} 2 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.5"} 4 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.6"} 4 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.7"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.8"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="0.9"} 5 +inference_extension_prefix_indexer_hit_ratio_bucket{le="1"} 6 +inference_extension_prefix_indexer_hit_ratio_bucket{le="+Inf"} 6 +inference_extension_prefix_indexer_hit_ratio_sum 2.7 +inference_extension_prefix_indexer_hit_ratio_count 6 diff --git a/pkg/epp/metrics/testdata/prefix_indexer_size_metric b/pkg/epp/metrics/testdata/prefix_indexer_size_metric new file mode 100644 index 000000000..9799b1729 --- /dev/null +++ b/pkg/epp/metrics/testdata/prefix_indexer_size_metric @@ -0,0 +1,3 @@ +# HELP inference_extension_prefix_indexer_size [ALPHA] Size of the prefix indexer. +# TYPE inference_extension_prefix_indexer_size gauge +inference_extension_prefix_indexer_size{} 4096 diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index a4f4c2950..e321ca2bf 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -18,18 +18,23 @@ package scheduling import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix" ) // NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. 
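+// Options are applied in order after the base fields are set. For example (illustrative names):
+//   cfg := NewSchedulerConfig(nil, filters, scorers, picker, nil, AddPrefixPlugin(prefixConfig, 1))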
func NewSchedulerConfig(preSchedulePlugins []plugins.PreSchedule, filters []plugins.Filter, scorers map[plugins.Scorer]int, - picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule) *SchedulerConfig { - return &SchedulerConfig{ + picker plugins.Picker, postSchedulePlugins []plugins.PostSchedule, opts ...ConfigOption) *SchedulerConfig { + config := &SchedulerConfig{ preSchedulePlugins: preSchedulePlugins, filters: filters, scorers: scorers, picker: picker, postSchedulePlugins: postSchedulePlugins, } + for _, opt := range opts { + opt(config) + } + return config } // SchedulerConfig provides a configuration for the scheduler which influence routing decisions. @@ -40,3 +45,16 @@ type SchedulerConfig struct { picker plugins.Picker postSchedulePlugins []plugins.PostSchedule } + +type ConfigOption func(*SchedulerConfig) + +// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/813): Replace this +// with a more generic way to add plugins. +func AddPrefixPlugin(prefixConfig prefix.Config, weight int) ConfigOption { + return func(cfg *SchedulerConfig) { + prefixPlugin := prefix.New(prefixConfig) + cfg.preSchedulePlugins = append(cfg.preSchedulePlugins, prefixPlugin) + cfg.postSchedulePlugins = append(cfg.postSchedulePlugins, prefixPlugin) + cfg.scorers[prefixPlugin] = weight + } +} diff --git a/pkg/epp/scheduling/plugins/prefix/indexer.go b/pkg/epp/scheduling/plugins/prefix/indexer.go new file mode 100644 index 000000000..2017ba175 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/indexer.go @@ -0,0 +1,173 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prefix + +import ( + "context" + "sync" + "time" + "unsafe" + + "container/list" + + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func newIndexer(maxCacheSize int) *indexer { + t := &indexer{ + maxCacheSize: maxCacheSize, + table: make(map[BlockHash]map[ServerID]*list.Element), + ll: list.New(), + } + go t.ReportCacheSize(time.Second) + return t +} + +// An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that +// prefix cached . +type indexer struct { + mu sync.RWMutex + maxCacheSize int + table map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server + ll *list.List // LinkedList to keep track of the order of entries +} + +// value is the value stored in the linked list. +type value struct { + server ServerID + hash BlockHash +} + +// Get returns the set of servers that have the given prefix hash cached. +func (i *indexer) Get(hash BlockHash) map[ServerID]bool { + i.mu.RLock() + defer i.mu.RUnlock() + res := map[ServerID]bool{} + for server := range i.table[hash] { + res[server] = true + } + return res +} + +// Add adds a list of prefix hashes of a single request to the server the request was sent to. 
+// The intuition is that this server is likely to have the prefix cached, so next time a request +// sharing the longest prefix should be sent to the same server to take advantage of the cache hit. +func (i *indexer) Add(hashes []BlockHash, server ServerID) { + i.mu.Lock() + defer i.mu.Unlock() + for _, hash := range hashes { + i.add(hash, server) + } +} + +func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) { + servers, ok := i.table[hash] + if !ok { + return nil, false + } + e, ok := servers[server] + return e, ok +} + +func (i *indexer) add(hash BlockHash, server ServerID) { + e, exists := i.check(hash, server) + if exists { + i.ll.MoveToBack(e) + } else { + i.create(hash, server) + } +} + +func (i *indexer) create(hash BlockHash, server ServerID) { + for i.ll.Len() >= i.maxCacheSize { + // Evict the least recently used entry if we've exceeded the max cache size + i.evict() + } + + if _, ok := i.table[hash]; !ok { + i.table[hash] = make(map[ServerID]*list.Element) + } + v := &value{ + server: server, + hash: hash, + } + e := i.ll.PushBack(v) + i.table[hash][server] = e +} + +// evict removes the least recently used entry from the cache +func (i *indexer) evict() { + oldestNode := i.ll.Front() + if oldestNode == nil { + return + } + i.ll.Remove(oldestNode) + + v := oldestNode.Value.(*value) + hash := v.hash + server := v.server + // Remove from the hash map + serverMap := i.table[hash] + delete(serverMap, server) + + // If this was the last server for this hash, remove the hash entry entirely + if len(serverMap) == 0 { + delete(i.table, hash) + } + + log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server) +} + +// ReportCacheSize starts a goroutine that periodically reports the cache size metric +func (i *indexer) ReportCacheSize(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + i.mu.RLock() + metrics.RecordPrefixCacheSize(int64(i.ll.Len())) + log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000) + i.mu.RUnlock() + } +} + +// estimateEntrySize estimates the memory size of a cache entry in bytes. +func (i *indexer) estimateEntrySize() int { + size := 0 + + // Estimate the size of a node in the linked list. + // First get the size of the node struct via unsafe.Sizeof. + // The prev and next pointers are 8 bytes each on a 64-bit system. + // The BlockHash is a uint64, which is 8 bytes. + // The ServerID is a NamespacedName, which contains two strings (Name and Namespace). + // The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length). + // So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes. + size += int(unsafe.Sizeof(value{})) + // Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName). + size += 2 * 63 + + // Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored. + size += 8 // Size of the BlockHash (uint64). + size += 2 * 16 // Size of the ServerID string headers (NamespacedName). + size += 2 * 63 // Size of the Name and Namespace strings in ServerID. + size += 8 // Size of the pointer to the node in the hash map. + + // Based on the above estimates, the estimated size of an entry is: + // (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes. 
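+	// This is only a rough estimate: it assumes the maximum-length Name/Namespace strings above and,
+	// as noted, ignores the overhead of the map headers and buckets.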
+ return size +} diff --git a/pkg/epp/scheduling/plugins/prefix/indexer_test.go b/pkg/epp/scheduling/plugins/prefix/indexer_test.go new file mode 100644 index 000000000..596625d10 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/indexer_test.go @@ -0,0 +1,45 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package prefix + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIndexer_AddAndGet(t *testing.T) { + cache := newIndexer(2) + + hash1 := BlockHash(1) + server := ServerID{Namespace: "default", Name: "server1"} + + // Add an entry to the cache + cache.Add([]BlockHash{hash1}, server) + + // Retrieve the entry + assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry") + servers := cache.Get(hash1) + assert.Contains(t, servers, server, "Cache should contain the added server") + + // Add another entry to the cache, the cache size should be incremented to 2. + cache.Add([]BlockHash{BlockHash(2)}, server) + assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry") + + // Add another entry to the cache, which should evict the first one due to max size. + cache.Add([]BlockHash{BlockHash(3)}, server) + assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry") +} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin.go b/pkg/epp/scheduling/plugins/prefix/plugin.go new file mode 100644 index 000000000..6d7f03c10 --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/plugin.go @@ -0,0 +1,204 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package prefix + +import ( + "encoding/binary" + "fmt" + + "github.com/cespare/xxhash/v2" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + DefaultScorerWeight = 1 + // Attempt to return DefaultNumServersToMatch servers with their longest prefix match length. + // Why not just return the server with longest prefix match? + // It may not be the optimal choice, e.g., it may have a high queue depth. + // We optimistically search more than one to give more candidates for the scheduler to choose. + DefaultNumServersToMatch = 2 + // vLLM default token block size is 16, and a good guess of average characters per token is 4. 
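+	// 16 tokens per block * ~4 characters per token ≈ 64 characters, hence the default block size below.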
+ DefaultHashBlockSize = 64 + // The maximum number of blocks to match. Two long requests with the same prefix up to this + // limit will be indistinguishable. + // This parameter provides a trade-off between cache size, prefix matching speed and matching + // accuracy. Use a small value if most requests are short to reduce cache size and speed up the + // matching process. Use a large value if most requests are long to increase the matching accuracy. + DefaultMaxPrefixBlocks = 128 + // The indexer is an approximation to the actual prefix cache state on the model servers. + // A small capacity ensures a high accuracy of cache hit on the model server, but it will + // increase the chance of false negatives. A high capacity does the opposite. + // To properly size this, consider the sum of the total number of cache entries on all model + // servers. Consider the llama3 8B model on 3 H100 80GB GPUs. The size of the model weight is + // about 16GB. Assume 50% of the remaining HBM is used for caching prefixes, we have 32GB. Each + // token is about 128KB in size, so we can cache 250K tokens. Using the default block size of 16 + // in vLLM, we will have 250K / 16 = 15.6K blocks. In total we have 15.6K * 3 = 46.8K blocks, or + // roughly 50K. + // How much memory space does it require to hold the 50K block hashes? + // According to the estimates in indexer.estimateEntrySize(), the size of each entry is + // approximately 348 bytes. So in total we have 50K * 348 = 17.4MB. + DefaultLRUIndexerCapacity = 50000 +) + +type Config struct { + // The input prompt is broken into sizes of HashBlockSize to calculate block hashes . Requests + // with length shorter than the block size will be ignored. + HashBlockSize int + // MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will + // be ignored. + MaxPrefixBlocksToMatch int + // Max (approximate) size of the LRU indexer in number of entries. + LRUIndexerCapacity int +} + +type Plugin struct { + Config + indexer Indexer +} + +type Indexer interface { + Get(hash BlockHash) map[ServerID]bool + Add(hashes []BlockHash, server ServerID) +} + +// This is the state of this plugin to be used during a scheduling cycle. +type SchedulingContextState struct { + // PrefixHashes is a list of prefix hashes of the request prompt broken into blocks. + PrefixHashes []BlockHash + // A map of server to its longest prefix cache match length. + PrefixCacheServers map[ServerID]int +} + +// BlockHash is a hash of the block of request body. 
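+// Because hashes are chained (see hashPrompt), a BlockHash identifies a block's content together with its entire preceding prefix.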
+type BlockHash uint64 + +type ServerID k8stypes.NamespacedName + +func (s ServerID) String() string { + return k8stypes.NamespacedName(s).String() +} + +func New(config Config) *Plugin { + m := &Plugin{ + Config: config, + indexer: newIndexer(config.LRUIndexerCapacity), + } + return m +} + +func (m *Plugin) Name() string { + return "prefixCache" +} + +func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) { + hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch) + state := SchedulingContextState{ + PrefixHashes: hashes, + PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, DefaultNumServersToMatch), + } + ctx.SetPluginState(types.PluginName(m.Name()), state) + ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes) +} + +// If a request was routed to a server, record it in the cache: +func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) { + targetPod := res.TargetPod.GetPod() + state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState) + m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName)) + total := len(state.PrefixHashes) + matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)] + metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize) +} + +func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 { + state := ctx.GetPluginState(types.PluginName(m.Name())).(SchedulingContextState) + total := len(state.PrefixHashes) + podScoreFunc := func(pod types.Pod) float64 { + if total == 0 { + return 0 + } + matchLen := state.PrefixCacheServers[ServerID(pod.GetPod().NamespacedName)] + return float64(matchLen) / float64(total) + } + + scores := make(map[types.Pod]float64, len(pods)) + for _, pod := range pods { + scores[pod] = podScoreFunc(pod) + } + return scores +} + +// matchLongestPrefix returns a map of servers and length of prefix that each server caches. +func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, hashes []BlockHash, numServers int) map[ServerID]int { + if numServers > len(ctx.PodsSnapshot) { + numServers = len(ctx.PodsSnapshot) + } + res := make(map[ServerID]int) + // Use a greedy strategy to search from the longest prefix. + // NOTE: It's possible to further optimize this with a binary search. + for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- { + hash := hashes[i] + cachedServers := m.indexer.Get(hash) + if len(cachedServers) > 0 { + ctx.Logger.V(logutil.DEBUG).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i) + for server := range cachedServers { + // Update servers with their longest prefix match. + // If we already found this server with longer prefix match, don't update it. + if _, ok := res[server]; !ok { + res[server] = i + 1 + } + } + } + } + return res +} + +// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block. +// hash(0) is the hash of the model name, since different models generally don't share prefix cache. +// For block i, hash(i) = hash(block i content, hash(i-1)). 
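+// For example (illustrative, block size 4): for model "m" and prompt "aaaabbbb",
+// hash(0) = xxhash("m"), hash(1) = xxhash("aaaa" + bytes(hash(0))), hash(2) = xxhash("bbbb" + bytes(hash(1))),
+// where bytes() is the little-endian encoding used by toBytes.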
+func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []BlockHash { + prompt := []byte(ctx.Req.Prompt) + if len(prompt) < cacheBlockSize { + ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize) + return nil + } + if len(prompt) > cacheBlockSize*maxPrefixBlocks { + ctx.Logger.V(logutil.DEBUG).Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize) + prompt = prompt[:maxPrefixBlocks*cacheBlockSize] + } + // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model. + // If the last block is smaller than cacheBlockSize, it will be ignored. + res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize) + // Add the model to the first block hash so that different models have different hashes even with the same body. + res = append(res, BlockHash(xxhash.Sum64String(ctx.Req.ResolvedTargetModel))) + for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize { + block := prompt[i : i+cacheBlockSize] + prevBlockHash := res[len(res)-1] + block = append(block, toBytes(prevBlockHash)...) + res = append(res, BlockHash(xxhash.Sum64(block))) + } + return res +} + +func toBytes(i BlockHash) []byte { + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, uint64(i)) + return bytes +} diff --git a/pkg/epp/scheduling/plugins/prefix/plugin_test.go b/pkg/epp/scheduling/plugins/prefix/plugin_test.go new file mode 100644 index 000000000..9aa1dbf1c --- /dev/null +++ b/pkg/epp/scheduling/plugins/prefix/plugin_test.go @@ -0,0 +1,137 @@ +package prefix + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestPrefixPlugin(t *testing.T) { + config := Config{ + HashBlockSize: 4, + MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks, + LRUIndexerCapacity: DefaultLRUIndexerCapacity, + } + plugin := New(config) + + pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}} + pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}} + pods := []types.Pod{pod1, pod2} + + // First request. + req1 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaaaa", + } + ctx := types.NewSchedulingContext(context.Background(), req1, pods) + plugin.PreSchedule(ctx) + state := ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 6, hash block size is 4, the last 2 characters are ignored. + // Total hashes = 2 (the first one is for the model) + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") + + // Updated to use the new Score method signature + scores := plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + // Simulate pod1 was picked. + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // Second request doesn't share any prefix with first one. It should be added to the cache but + // the pod score should be 0. 
+ req2 := &types.LLMRequest{ + Model: "test-model2", + ResolvedTargetModel: "test-model2", + Prompt: "bbbbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req2, pods) + plugin.PreSchedule(ctx) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 6, hash block size is 4, the last 2 characters are ignored. + // Total hashes = 2 (the first one is for the model) + assert.Equal(t, 2, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + // Simulate pod2 was picked. + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod2}) + + // Third request shares partial prefix with first one. + req3 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaabbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req3, pods) + plugin.PreSchedule(ctx) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 8, hash block size is 4, so 2 hashes will be calculated. + // Total hashes = 3 (the first one is for the model) + assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(2)/float64(3), scores[pod1], "score should be 2/3 - the model and the first prefix block match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // 4th request is same as req3 except the model is different, still no match. + req4 := &types.LLMRequest{ + Model: "test-model-new", + ResolvedTargetModel: "test-model-new", + Prompt: "aaaabbbb", + } + ctx = types.NewSchedulingContext(context.Background(), req4, pods) + plugin.PreSchedule(ctx) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 8, hash block size is 4, so 2 hashes will be calculated. + // Total hashes = 3 (the first one is for the model) + assert.Equal(t, 3, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 0, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, float64(0), scores[pod1], "score for pod1") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) + + // 5th request shares partial prefix with 3rd one. 
+ req5 := &types.LLMRequest{ + Model: "test-model1", + ResolvedTargetModel: "test-model1", + Prompt: "aaaabbbbcccc", + } + ctx = types.NewSchedulingContext(context.Background(), req5, pods) + plugin.PreSchedule(ctx) + state = ctx.GetPluginState(types.PluginName(plugin.Name())).(SchedulingContextState) + t.Logf("Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers) + // Input size is 12, hash block size is 4, so 3 hashes will be calculated. + // Total hashes = 4 (the first one is for the model) + assert.Equal(t, 4, len(state.PrefixHashes), "number of hashes is incorrect") + assert.Equal(t, 1, len(state.PrefixCacheServers), "pod1 should have cached the aaaa prefix") + + // Updated to use the new Score method signature + scores = plugin.Score(ctx, pods) + assert.Equal(t, 0.75, scores[pod1], "score should be 0.75 - the model and the first 2 prefix blocks match") + assert.Equal(t, float64(0), scores[pod2], "score for pod2") + + plugin.PostSchedule(ctx, &types.Result{TargetPod: pod1}) +} diff --git a/pkg/epp/scheduling/plugins/scorer/kvcache.go b/pkg/epp/scheduling/plugins/scorer/kvcache.go index 0877691d1..dbb6079dc 100644 --- a/pkg/epp/scheduling/plugins/scorer/kvcache.go +++ b/pkg/epp/scheduling/plugins/scorer/kvcache.go @@ -20,6 +20,10 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +const ( + DefaultKVCacheScorerWeight = 1 +) + type KVCacheScorer struct{} func (ss *KVCacheScorer) Name() string { diff --git a/pkg/epp/scheduling/plugins/scorer/queue.go b/pkg/epp/scheduling/plugins/scorer/queue.go index 3df9d4140..bbe6b6961 100644 --- a/pkg/epp/scheduling/plugins/scorer/queue.go +++ b/pkg/epp/scheduling/plugins/scorer/queue.go @@ -22,6 +22,10 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" ) +const ( + DefaultQueueScorerWeight = 1 +) + type QueueScorer struct{} func (q *QueueScorer) Name() string { diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go index 795ef65d2..daf27bf83 100644 --- a/pkg/epp/scheduling/types/types.go +++ b/pkg/epp/scheduling/types/types.go @@ -19,6 +19,7 @@ package types import ( "context" "fmt" + "sync" "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" @@ -62,8 +63,26 @@ type SchedulingContext struct { Logger logr.Logger Req *LLMRequest PodsSnapshot []Pod + // PluginState can be used by plugins to store state during a scheduling cycle, to communicate + // between different extension points. + PluginState map[PluginName]any + pluginStateMu *sync.RWMutex } +func (sc *SchedulingContext) GetPluginState(pluginName PluginName) any { + sc.pluginStateMu.RLock() + defer sc.pluginStateMu.RUnlock() + return sc.PluginState[pluginName] +} + +func (sc *SchedulingContext) SetPluginState(pluginName PluginName, state any) { + sc.pluginStateMu.Lock() + defer sc.pluginStateMu.Unlock() + sc.PluginState[pluginName] = state +} + +type PluginName string + func (pm *PodMetrics) String() string { if pm == nil { return "" @@ -87,10 +106,12 @@ type PodMetrics struct { func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { logger := log.FromContext(ctx).WithValues("request", req) return &SchedulingContext{ - Context: ctx, - Logger: logger, - Req: req, - PodsSnapshot: pods, + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + PluginState: make(map[PluginName]any), + pluginStateMu: &sync.RWMutex{}, } }