Skip to content

Commit 97acfa6

Browse files
committed
Add score weight to XXScorerConfig
1 parent 5278747 commit 97acfa6

File tree

5 files changed

+76
-39
lines changed

5 files changed

+76
-39
lines changed

cmd/epp/main.go

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import (
4545
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics/collectors"
4646
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
4747
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/prefix"
48+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
4849
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
4950
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
5051
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -112,7 +113,8 @@ var (
112113
setupLog = ctrl.Log.WithName("setup")
113114

114115
// Environment variables
115-
schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog)
116+
schedulerV2 = envutil.GetEnvString("EXPERIMENTAL_USE_SCHEDULER_V2", "false", setupLog)
117+
prefixCacheScheduling = envutil.GetEnvString("ENABLE_PREFIX_CACHE_SCHEDULING", "false", setupLog)
116118
)
117119

118120
func loadPrefixCacheConfig() prefix.Config {
@@ -125,16 +127,6 @@ func loadPrefixCacheConfig() prefix.Config {
125127
}
126128
}
127129

128-
func loadSchedulingScorerWeights() scheduling.ScorerWeights {
129-
baseLogger := log.Log.WithName("env-config")
130-
131-
return scheduling.ScorerWeights{
132-
Prefix: envutil.GetEnvInt("PREFIX_CACHE_SCORE_WEIGHT", 3, baseLogger),
133-
Queue: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", 2, baseLogger),
134-
KVCache: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", 1, baseLogger),
135-
}
136-
}
137-
138130
func main() {
139131
if err := run(); err != nil {
140132
os.Exit(1)
@@ -199,9 +191,21 @@ func run() error {
199191

200192
scheduler := scheduling.NewScheduler(datastore)
201193
if schedulerV2 == "true" {
202-
schedConfig := scheduling.CreateConfig(loadSchedulingScorerWeights(), loadPrefixCacheConfig())
203-
setupLog.Info("Creating scheduler", "config", *schedConfig)
204-
scheduler = scheduling.NewSchedulerWithConfig(datastore, schedConfig)
194+
queueConfig := scorer.QueueScorerConfig{
195+
Weight: envutil.GetEnvInt("QUEUE_SCORE_WEIGHT", scorer.DefaultQueueScorerWeight, setupLog),
196+
}
197+
kvCacheConfig := scorer.KVCacheScorerConfig{
198+
Weight: envutil.GetEnvInt("KV_CACHE_SCORE_WEIGHT", scorer.DefaultKVCacheScorerWeight, setupLog),
199+
}
200+
schedConfigOpts := []scheduling.ConfigOption{
201+
scheduling.WithQueuePlugin(queueConfig),
202+
scheduling.WithKVCachePlugin(kvCacheConfig),
203+
}
204+
if prefixCacheScheduling == "true" {
205+
schedConfigOpts = append(schedConfigOpts, scheduling.WithPrefixPlugin(loadPrefixCacheConfig()))
206+
}
207+
schedulerConfig := scheduling.CreateConfig(schedConfigOpts...)
208+
scheduler = scheduling.NewSchedulerWithConfig(datastore, schedulerConfig)
205209
}
206210
serverRunner := &runserver.ExtProcServerRunner{
207211
GrpcPort: *grpcPort,

pkg/epp/scheduling/config_v2.go

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,43 @@ import (
2626
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2727
)
2828

29-
func CreateConfig(weights ScorerWeights, prefixConfig prefix.Config) *SchedulerConfig {
30-
prefixPlugin := prefix.New(prefixConfig)
31-
queuePlugin := &scorer.QueueScorer{}
32-
kvCachePlugin := &scorer.KVCacheScorer{}
29+
func CreateConfig(opts ...ConfigOption) *SchedulerConfig {
3330
config := &SchedulerConfig{
34-
PreSchedulePlugins: []plugins.PreSchedule{prefixPlugin},
35-
PostSchedulePlugins: []plugins.PostSchedule{prefixPlugin},
36-
Scorers: map[plugins.Scorer]int{
37-
prefixPlugin: weights.Prefix,
38-
queuePlugin: weights.Queue,
39-
kvCachePlugin: weights.KVCache,
40-
},
41-
Filters: []plugins.Filter{&sheddableRequestFilterV2{}},
42-
Picker: &picker.MaxScorePicker{},
31+
PreSchedulePlugins: []plugins.PreSchedule{},
32+
PostSchedulePlugins: []plugins.PostSchedule{},
33+
Scorers: map[plugins.Scorer]int{},
34+
Filters: []plugins.Filter{&sheddableRequestFilterV2{}},
35+
Picker: &picker.MaxScorePicker{},
36+
}
37+
for _, opt := range opts {
38+
opt(config)
4339
}
4440
return config
4541
}
4642

47-
type ScorerWeights struct {
48-
Prefix int
49-
Queue int
50-
KVCache int
43+
type ConfigOption func(*SchedulerConfig)
44+
45+
func WithPrefixPlugin(prefixConfig prefix.Config) ConfigOption {
46+
return func(cfg *SchedulerConfig) {
47+
prefixPlugin := prefix.New(prefixConfig)
48+
cfg.PreSchedulePlugins = append(cfg.PreSchedulePlugins, prefixPlugin)
49+
cfg.PostSchedulePlugins = append(cfg.PostSchedulePlugins, prefixPlugin)
50+
cfg.Scorers[prefixPlugin] = prefixConfig.Weight
51+
}
52+
}
53+
54+
func WithQueuePlugin(queueConfig scorer.QueueScorerConfig) ConfigOption {
55+
return func(cfg *SchedulerConfig) {
56+
queuePlugin := &scorer.QueueScorer{}
57+
cfg.Scorers[queuePlugin] = queueConfig.Weight
58+
}
59+
}
60+
61+
func WithKVCachePlugin(kvCacheConfig scorer.KVCacheScorerConfig) ConfigOption {
62+
return func(cfg *SchedulerConfig) {
63+
kvCachePlugin := &scorer.KVCacheScorer{}
64+
cfg.Scorers[kvCachePlugin] = kvCacheConfig.Weight
65+
}
5166
}
5267

5368
type sheddableRequestFilterV2 struct{}

pkg/epp/scheduling/plugins/prefix/plugin.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
)
2828

2929
const (
30+
DefaultScorerWeight = 1
3031
// Attempt to return DefaultNumServersToMatch servers with their longest prefix match length.
3132
// Why not just return the server with longest prefix match?
3233
// It may not be the optimal choice, e.g., it may have a high queue depth.
@@ -49,6 +50,7 @@ const (
4950
)
5051

5152
type Config struct {
53+
Weight int
5254
// The input prompt is broken into sizes of HashBlockSize to calculate block hashes. Requests
5355
// with length shorter than the block size will be ignored.
5456
HashBlockSize int
@@ -65,7 +67,7 @@ var DefaultConfig = Config{
6567
LRUIndexerCapacity: DefaultLRUIndexerCapacity,
6668
}
6769

68-
type plugin struct {
70+
type Plugin struct {
6971
Config
7072
indexer Indexer
7173
}
@@ -75,34 +77,34 @@ type Indexer interface {
7577
Add(hashes []types.BlockHash, server types.ServerID)
7678
}
7779

78-
func New(config Config) *plugin {
79-
m := &plugin{
80+
func New(config Config) *Plugin {
81+
m := &Plugin{
8082
Config: config,
8183
indexer: newIndexer(config.LRUIndexerCapacity),
8284
}
8385
return m
8486
}
8587

86-
func (m *plugin) Name() string {
88+
func (m *Plugin) Name() string {
8789
return "prefixCache"
8890
}
8991

90-
func (m *plugin) PreSchedule(ctx *types.SchedulingContext) {
92+
func (m *Plugin) PreSchedule(ctx *types.SchedulingContext) {
9193
ctx.PrefixHashes = hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
9294
ctx.PrefixCacheServers = m.matchLongestPrefix(ctx, DefaultNumServersToMatch)
9395
ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("PreSchedule, cached servers: %+v", ctx.PrefixCacheServers), "hashes", ctx.PrefixHashes)
9496
}
9597

9698
// If a request was routed to a server, record it in the cache:
97-
func (m *plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
99+
func (m *Plugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
98100
targetPod := res.TargetPod.GetPod()
99101
m.indexer.Add(ctx.PrefixHashes, types.ServerID(targetPod.NamespacedName))
100102
total := len(ctx.PrefixHashes)
101103
matchLen := ctx.PrefixCacheServers[types.ServerID(targetPod.NamespacedName)]
102104
metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
103105
}
104106

105-
func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
107+
func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
106108
total := len(ctx.PrefixHashes)
107109
podScoreFunc := func(ctx *types.SchedulingContext, pod types.Pod) float64 {
108110
if total == 0 {
@@ -120,7 +122,7 @@ func (m *plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
120122
}
121123

122124
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
123-
func (m *plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int {
125+
func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, numServers int) map[types.ServerID]int {
124126
if numServers > len(ctx.PodsSnapshot) {
125127
numServers = len(ctx.PodsSnapshot)
126128
}

pkg/epp/scheduling/plugins/scorer/kvcache.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ import (
2020
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2121
)
2222

23+
const (
24+
DefaultKVCacheScorerWeight = 1
25+
)
26+
27+
type KVCacheScorerConfig struct {
28+
Weight int
29+
}
30+
2331
type KVCacheScorer struct{}
2432

2533
func (ss *KVCacheScorer) Name() string {

pkg/epp/scheduling/plugins/scorer/queue.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ import (
2222
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
2323
)
2424

25+
const (
26+
DefaultQueueScorerWeight = 1
27+
)
28+
29+
type QueueScorerConfig struct {
30+
Weight int
31+
}
32+
2533
type QueueScorer struct{}
2634

2735
func (q *QueueScorer) Name() string {

0 commit comments

Comments
 (0)