Skip to content

Commit 09f7448

Browse files
authored
Merge pull request kubernetes-sigs#48 from oglok/prefix_scorer
Prefix Aware Scorer
2 parents b7689d0 + a0e02c0 commit 09f7448

File tree

6 files changed

+548
-2
lines changed

6 files changed

+548
-2
lines changed

pkg/epp/scheduling/local_config.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ import (
2929
const (
3030
kvCacheScorerEnablementEnvVar = "ENABLE_KVCACHE_AWARE_SCORER"
3131
loadAwareScorerEnablementEnvVar = "ENABLE_LOAD_AWARE_SCORER"
32+
prefixScorerEnablementEnvVar = "ENABLE_PREFIX_AWARE_SCORER" // prefix scorer is created only when this is exactly "true"
3233
pdFilterEnablementEnvVar = "ENABLE_PD_FILTER"
3334

3435
kvCacheScorerWeightEnvVar = "KVCACHE_AWARE_SCORER_WEIGHT"
3536
loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT"
37+
prefixScorerWeightEnvVar = "PREFIX_AWARE_SCORER_WEIGHT" // integer weight of the prefix scorer; defaults to 1 when unset
3638
)
3739

3840
func init() {
@@ -44,6 +46,7 @@ func setDefaultConfig() {
4446
// this configuration is a temporary state, it should be better streamlined.
4547
setLoadAwareScorer()
4648
setKVCacheAwareScorer()
49+
setPrefixScorer()
4750

4851
defaultConfig.picker = picker.NewMaxScorePicker()
4952
}
@@ -81,3 +84,20 @@ func setKVCacheAwareScorer() {
8184
defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight
8285
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
8386
}
87+
88+
func setPrefixScorer() {
89+
ctx := context.Background()
90+
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
91+
92+
if envutil.GetEnvString(prefixScorerEnablementEnvVar, "false", loggerDebug) != "true" {
93+
loggerDebug.Info("Skipping PrefixScorer creation as it is not enabled")
94+
return
95+
}
96+
97+
prefixScorerWeight := envutil.GetEnvInt(prefixScorerWeightEnvVar, 1, loggerDebug)
98+
prefixScorer := scorer.NewPrefixAwareScorer(nil)
99+
defaultConfig.scorers[prefixScorer] = prefixScorerWeight // TODO: make configurable
100+
defaultConfig.postResponsePlugins = append(defaultConfig.postResponsePlugins, prefixScorer)
101+
102+
loggerDebug.Info("Initialized PrefixAwareScorer", "weight", prefixScorerWeight)
103+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer
18+
19+
import (
20+
"sigs.k8s.io/controller-runtime/pkg/log"
21+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
22+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
23+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
24+
)
25+
26+
const prefixAwareScorerName = "prefix-aware-scorer"
27+
28+
// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match
29+
// between the request's prompt and stored prefixes. The score is normalized between 0 and 1,
30+
// where 1 represents the longest matching prefix. Its PostResponse method feeds the store that Score reads.
31+
type PrefixAwareScorer struct {
32+
prefixStore *PrefixStore // queried by Score (FindMatchingPods) and appended to by PostResponse (AddEntry)
33+
}
34+
35+
var _ plugins.Scorer = &PrefixAwareScorer{}
36+
37+
// NewPrefixAwareScorer creates a new PrefixAwareScorer with the given
38+
// PrefixStoreConfig. If the config is nil, default is used.
39+
func NewPrefixAwareScorer(config *PrefixStoreConfig) *PrefixAwareScorer {
40+
return &PrefixAwareScorer{
41+
prefixStore: NewPrefixStore(config),
42+
}
43+
}
44+
45+
func (s *PrefixAwareScorer) Name() string {
46+
return "prefix-aware-scorer"
47+
}
48+
49+
// Score scores the target pods based on the longest prefix match.
50+
func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
51+
loggerDebug := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)
52+
if ctx.Req == nil {
53+
loggerDebug.Info("Request is nil, skipping scoring")
54+
return nil
55+
}
56+
57+
scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model)
58+
loggerDebug.Info("Got pod scores", "scores", scores)
59+
60+
if len(scores) == 0 {
61+
loggerDebug.Info("No scores found for pods")
62+
return nil
63+
}
64+
65+
podToKey := func(pod types.Pod) (string, bool) {
66+
if pod.GetPod() == nil {
67+
return "", false
68+
}
69+
70+
return pod.GetPod().NamespacedName.String(), true
71+
}
72+
73+
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
74+
}
75+
76+
// PostResponse implements the PostResponsePlugin interface.
77+
// It adds the prefix to the PrefixStore for the given pod.
78+
func (s *PrefixAwareScorer) PostResponse(ctx *types.SchedulingContext, pod types.Pod) {
79+
debugLogger := log.FromContext(ctx).WithName(prefixAwareScorerName).V(logutil.DEBUG)
80+
81+
if ctx.Req == nil {
82+
debugLogger.Info("Request is nil, skipping PostResponse")
83+
return
84+
}
85+
86+
if pod.GetPod() == nil {
87+
debugLogger.Info("Pod is nil, skipping PostResponse", "req", ctx.Req, "pod", pod)
88+
return
89+
}
90+
91+
if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil {
92+
debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod)
93+
return
94+
}
95+
}
96+
97+
// GetPrefixStore returns the scorer's own PrefixStore (the same pointer, not a copy), so entries added through it are seen by later Score calls.
98+
func (s *PrefixAwareScorer) GetPrefixStore() *PrefixStore {
99+
return s.prefixStore
100+
}
101+
102+
// podToKeyFunc is a function type that converts a Pod to the string key used to look up its raw score.
103+
// It returns the key and a boolean indicating success; pods for which it reports false are skipped by the caller.
104+
type podToKeyFunc func(pod types.Pod) (string, bool)
105+
106+
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
107+
scores map[string]int) map[types.Pod]float64 {
108+
scoredPods := make(map[types.Pod]float64)
109+
minScore, maxScore := getMinMax(scores)
110+
111+
for _, pod := range pods {
112+
key, ok := podToKey(pod)
113+
if !ok {
114+
continue
115+
}
116+
117+
if score, ok := scores[key]; ok {
118+
if minScore == maxScore {
119+
scoredPods[pod] = 1.0
120+
continue
121+
}
122+
123+
scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore)
124+
} else {
125+
scoredPods[pod] = 0.0
126+
}
127+
}
128+
129+
return scoredPods
130+
}
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer_test
18+
19+
import (
20+
"context"
21+
k8stypes "k8s.io/apimachinery/pkg/types"
22+
"sigs.k8s.io/controller-runtime/pkg/log"
23+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+
"testing"
27+
)
28+
29+
func TestPrefixAwareScorer(t *testing.T) {
30+
ctx := context.Background()
31+
logger := log.FromContext(ctx)
32+
ctx = log.IntoContext(ctx, logger)
33+
34+
// Create test pods
35+
pod1 := &types.PodMetrics{
36+
Pod: &backendmetrics.Pod{
37+
NamespacedName: k8stypes.NamespacedName{
38+
Name: "pod1",
39+
Namespace: "default",
40+
},
41+
},
42+
Metrics: &backendmetrics.Metrics{},
43+
}
44+
pod2 := &types.PodMetrics{
45+
Pod: &backendmetrics.Pod{
46+
NamespacedName: k8stypes.NamespacedName{
47+
Name: "pod2",
48+
Namespace: "default",
49+
},
50+
},
51+
Metrics: &backendmetrics.Metrics{},
52+
}
53+
54+
tests := []struct {
55+
name string
56+
weight float64
57+
prompt string
58+
modelName string
59+
prefixToAdd string
60+
podToAdd k8stypes.NamespacedName
61+
prefixModel string // Model name to use when adding the prefix
62+
expectedScores map[types.Pod]float64
63+
}{
64+
{
65+
name: "no prompt",
66+
weight: 1.0,
67+
prompt: "",
68+
modelName: "model1",
69+
prefixToAdd: "hello",
70+
podToAdd: pod1.Pod.NamespacedName,
71+
prefixModel: "model1",
72+
expectedScores: map[types.Pod]float64{}, // No prompt means zero scores
73+
},
74+
{
75+
name: "exact prefix match",
76+
weight: 1.0,
77+
prompt: "hello world",
78+
modelName: "model1",
79+
prefixToAdd: "hello",
80+
podToAdd: pod1.Pod.NamespacedName,
81+
prefixModel: "model1",
82+
expectedScores: map[types.Pod]float64{
83+
pod1: 1.0,
84+
pod2: 0.0,
85+
}, // pod1 matches, pod2 doesn't
86+
},
87+
{
88+
name: "no prefix match",
89+
weight: 1.0,
90+
prompt: "goodbye",
91+
modelName: "model1",
92+
prefixToAdd: "hello",
93+
podToAdd: pod1.Pod.NamespacedName,
94+
prefixModel: "model1",
95+
expectedScores: map[types.Pod]float64{}, // No matching prefix
96+
},
97+
{
98+
name: "different model name",
99+
weight: 1.0,
100+
prompt: "hello world",
101+
modelName: "model2", // Try to find with model2
102+
prefixToAdd: "hello",
103+
podToAdd: pod1.Pod.NamespacedName,
104+
prefixModel: "model1", // But prefix was added with model1
105+
expectedScores: map[types.Pod]float64{}, // Model name mismatch should result in no match
106+
},
107+
{
108+
name: "custom weight",
109+
weight: 0.5,
110+
prompt: "hello world",
111+
modelName: "model1",
112+
prefixToAdd: "hello",
113+
podToAdd: pod1.Pod.NamespacedName,
114+
prefixModel: "model1",
115+
expectedScores: map[types.Pod]float64{
116+
pod1: 0.5, // Pod1 matches with weight
117+
pod2: 0.0, // Pod2 doesn't match
118+
}, // Weight affects score
119+
},
120+
}
121+
122+
for _, tt := range tests {
123+
t.Run(tt.name, func(t *testing.T) {
124+
// Reset prefix store for each test
125+
config := scorer.DefaultPrefixStoreConfig()
126+
config.BlockSize = 5 // set small chunking for testing
127+
128+
s := scorer.NewPrefixAwareScorer(config)
129+
130+
// Add prefix if specified
131+
if tt.prefixToAdd != "" {
132+
err := s.GetPrefixStore().AddEntry(tt.prefixModel,
133+
tt.prefixToAdd, &tt.podToAdd)
134+
if err != nil {
135+
t.Fatalf("Failed to add prefix: %v", err)
136+
}
137+
}
138+
139+
// Create test context
140+
sCtx := types.NewSchedulingContext(ctx, &types.LLMRequest{
141+
Prompt: tt.prompt,
142+
ResolvedTargetModel: tt.modelName,
143+
}, []types.Pod{}, 0)
144+
145+
// Score pods
146+
pods := []types.Pod{pod1, pod2}
147+
scores := s.Score(sCtx, pods)
148+
149+
for p, score := range scores {
150+
if score != tt.expectedScores[p] {
151+
t.Errorf("Pod %v: expected score %v, got %v", p, tt.expectedScores[p], score)
152+
}
153+
}
154+
})
155+
}
156+
}

0 commit comments

Comments
 (0)