Skip to content

Commit 47ea097

Browse files
mayabarclubanderson
authored andcommitted
Add P/D scheduler (#115)
* Add P/D scheduler - use 2 schedulers in it, one for prefill and one for decode. P/D scheduler is enabled by environment variable value, list of scorers and their weight are defined by environment variables + delete pd-filter * Remove unused variable * Update readme file with envirnment variables relevant to P/D scheduler * Fix problem caused by merge * Add documentation for PDScheduler.Schedule function * Update names of prefill and decode filters to avoid spaces * Update comment for prefill/decode fitlers * Change IsPDEnabled to PDEnabled * Fix typo in readme * Fix pd scheduler behavior for short promprts * Fix prefill/decode related text in readme * Remove redundant filter creation of prefil/decode filters + make promptLengthThreshold local Add function for schedulerContext creation * Fixes in readme * fix compilation prblem * add pd scheduler test * add postResponse plugins array to prefile and decode config * fix comment in test * fix pd-scheduler test
1 parent 5c06734 commit 47ea097

File tree

9 files changed

+488
-67
lines changed

9 files changed

+488
-67
lines changed

README.md

+37-4
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,56 @@ This project offers tools for AI Inference, enabling developers to build [Infere
99
---
1010
## Temporary Fork Configuration
1111

12-
To enable KVCacheAwareScorer, the following env vars must be configured:
12+
To enable the KVCacheAwareScorer, the following environment variables must be configured:
1313
```
1414
export ENABLE_KVCACHE_AWARE_SCORER=true
1515
export KVCACHE_AWARE_SCORER_WEIGHT=1.0
1616
export KVCACHE_INDEXER_REDIS_ADDR=<redis-service>
1717
export HF_TOKEN=<HuggingFace Token that has access to the vLLM models>
1818
```
1919

20-
To enable LoadAwareScorer, the following env vars must be configured:
20+
To enable the LoadAwareScorer, the following environment variables must be configured:
2121
```
2222
export ENABLE_LOAD_AWARE_SCORER=true
2323
export LOAD_AWARE_SCORER_WEIGHT=1.0
2424
```
2525

26-
To enable PDFilter, the following env var must be configured:
26+
To enable Prefill/Decode (PD) processing, the following environment variable must be configured:
2727
```
28-
export ENABLE_PD_FILTER=true
28+
export PD_ENABLED=true
29+
```
30+
31+
To define the prompt length threshold (requests with a prompt longer than the value defined here will be processed using the prefill-decode process), the following environment variable must be configured:
32+
```
33+
export PD_PROMPT_LEN_THRESHOLD=10
34+
```
35+
36+
Prefill configuration:
37+
38+
To enable and configure the kv cache scorer for prefill, the following environment variables must be configured:
39+
```
40+
export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true
41+
export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0
42+
```
43+
44+
To enable and configure the load aware scorer for prefill, the following environment variables must be configured:
45+
```
46+
export PREFILL_ENABLE_LOAD_AWARE_SCORER=true
47+
export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0
48+
```
49+
50+
Decode configuration:
51+
52+
To enable and configure the kv cache scorer for decode, the following environment variables must be configured:
53+
```
54+
export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true
55+
export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0
56+
```
57+
58+
To enable and configure the load aware scorer for decode, the following environment variables must be configured:
59+
```
60+
export DECODE_ENABLE_LOAD_AWARE_SCORER=true
61+
export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0
2962
```
3063
---
3164
[Inference Gateways]:#concepts-and-definitions

pkg/epp/scheduling/config_utils.go

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scheduling
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"github.com/go-logr/logr"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
26+
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
27+
)
28+
29+
const (
30+
prefillKvCacheScorerEnablementEnvVar = "PREFILL_ENABLE_KVCACHE_AWARE_SCORER"
31+
prefillLoadAwareScorerEnablementEnvVar = "PREFILL_ENABLE_LOAD_AWARE_SCORER"
32+
decodeKvCacheScorerEnablementEnvVar = "DECODE_ENABLE_KVCACHE_AWARE_SCORER"
33+
decodeLoadAwareScorerEnablementEnvVar = "DECODE_ENABLE_LOAD_AWARE_SCORER"
34+
35+
prefillKvCacheScorerWeightEnvVar = "PREFILL_KVCACHE_AWARE_SCORER_WEIGHT"
36+
prefillLoadAwareScorerWeightEnvVar = "PREFILL_LOAD_AWARE_SCORER_WEIGHT"
37+
decodeKvCacheScorerWeightEnvVar = "DECODE_KVCACHE_AWARE_SCORER_WEIGHT"
38+
decodeLoadAwareScorerWeightEnvVar = "DECODE_LOAD_AWARE_SCORER_WEIGHT"
39+
40+
pdEnabledEnvKey = "PD_ENABLED"
41+
42+
pdPromptLenThresholdEnvKey = "PD_PROMPT_LEN_THRESHOLD"
43+
pdPromptLenThresholdDefault = 10
44+
)
45+
46+
const (
47+
loadAwareScorerName = "LoadAwareScorer"
48+
kvCacheAwareScorerName = "KVCacheAwareScorer"
49+
)
50+
51+
func addScorerByEnvironment(ctx context.Context, config *SchedulerConfig, scorerName string, scorerEnabledEnvKey string, weightEnvKey string, logger logr.Logger) {
52+
if envutil.GetEnvString(scorerEnabledEnvKey, "false", logger) != "true" {
53+
logger.Info(fmt.Sprintf("Skipping %s creation as it is not enabled", scorerName))
54+
return
55+
}
56+
57+
weight := envutil.GetEnvInt(weightEnvKey, 1, logger)
58+
scorer, err := createScorerByName(ctx, scorerName)
59+
if err != nil {
60+
logger.Error(err, "Failed to create scorrer")
61+
return
62+
}
63+
64+
defaultConfig.scorers[scorer] = weight
65+
logger.Info("Initialized scorer", "scorer", scorerName, "weight", weight)
66+
}
67+
68+
func createScorerByName(ctx context.Context, name string) (plugins.Scorer, error) {
69+
switch name {
70+
case loadAwareScorerName:
71+
return &scorer.LoadAwareScorer{}, nil
72+
case kvCacheAwareScorerName:
73+
return scorer.NewKVCacheAwareScorer(ctx)
74+
}
75+
return nil, fmt.Errorf("invalid scorer type %s", name)
76+
}
77+
78+
func getPDEnabledFromEnvironment(logger logr.Logger) bool {
79+
return envutil.GetEnvString(pdEnabledEnvKey, "false", logger) == "true"
80+
}
81+
82+
func getPDPromptLenThresholdFromEnvironment(logger logr.Logger) int {
83+
return envutil.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger)
84+
}

pkg/epp/scheduling/local_config.go

-15
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121

2222
"sigs.k8s.io/controller-runtime/pkg/log"
23-
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
2423
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
2524
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
2625
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
@@ -45,7 +44,6 @@ func setDefaultConfig() {
4544
// this configuration is a temporary state, it should be better streamlined.
4645
setLoadAwareScorer()
4746
setKVCacheAwareScorer()
48-
setPDFilter()
4947

5048
defaultConfig.picker = picker.NewMaxScorePicker()
5149
}
@@ -83,16 +81,3 @@ func setKVCacheAwareScorer() {
8381
defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight
8482
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
8583
}
86-
87-
func setPDFilter() {
88-
ctx := context.Background()
89-
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
90-
91-
if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" {
92-
loggerDebug.Info("Skipping PDFilter creation as it is not enabled")
93-
return
94-
}
95-
96-
defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter)
97-
loggerDebug.Info("Initialized PDFilter")
98-
}

pkg/epp/scheduling/pd_config.go

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scheduling
18+
19+
import (
20+
"context"
21+
22+
"github.com/go-logr/logr"
23+
"sigs.k8s.io/controller-runtime/pkg/log"
24+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
27+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
28+
)
29+
30+
var prefillConfig = &SchedulerConfig{
31+
preSchedulePlugins: []plugins.PreSchedule{},
32+
filters: []plugins.Filter{filter.PrefillFilter},
33+
scorers: map[plugins.Scorer]int{},
34+
picker: picker.NewMaxScorePicker(),
35+
postSchedulePlugins: []plugins.PostSchedule{},
36+
postResponsePlugins: []plugins.PostResponse{},
37+
}
38+
var decodeConfig = &SchedulerConfig{
39+
preSchedulePlugins: []plugins.PreSchedule{},
40+
filters: []plugins.Filter{filter.DecodeFilter},
41+
scorers: map[plugins.Scorer]int{},
42+
picker: picker.NewMaxScorePicker(),
43+
postSchedulePlugins: []plugins.PostSchedule{},
44+
postResponsePlugins: []plugins.PostResponse{},
45+
}
46+
47+
var PDEnabled = false
48+
var promptLengthThreshold int
49+
50+
func init() {
51+
ctx := context.Background()
52+
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)
53+
54+
loadPrefillConfiguration(ctx, loggerDebug)
55+
loadDecodeConfiguration(ctx, loggerDebug)
56+
57+
// set IsPDEnabled by environment
58+
PDEnabled = getPDEnabledFromEnvironment(loggerDebug)
59+
promptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug)
60+
}
61+
62+
func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) {
63+
// add scorers
64+
addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
65+
addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)
66+
}
67+
68+
func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) {
69+
// add scorers
70+
addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
71+
addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)
72+
}

pkg/epp/scheduling/pd_scheduler.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
// Package scheduling implements request scheduling algorithms.
18+
package scheduling
19+
20+
import (
21+
"context"
22+
"fmt"
23+
24+
"sigs.k8s.io/controller-runtime/pkg/log"
25+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+
)
27+
28+
const (
29+
prefillPodHeader = "x-prefiller-url"
30+
)
31+
32+
func NewPDScheduler(datastore Datastore) *PDScheduler {
33+
return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig, defaultConfig)
34+
}
35+
36+
func NewPDSchedulerWithConfig(datastore Datastore, pConfig *SchedulerConfig, dConfig *SchedulerConfig, defConfig *SchedulerConfig) *PDScheduler {
37+
return &PDScheduler{
38+
datastore: datastore,
39+
prefillScheduler: NewSchedulerWithConfig(datastore, pConfig),
40+
decodeScheduler: NewSchedulerWithConfig(datastore, dConfig),
41+
defaultScheduler: NewSchedulerWithConfig(datastore, defConfig),
42+
}
43+
}
44+
45+
type PDScheduler struct {
46+
datastore Datastore
47+
prefillScheduler *Scheduler
48+
decodeScheduler *Scheduler
49+
defaultScheduler *Scheduler
50+
}
51+
52+
// Schedule finds the target pod based on metrics and the requested lora adapter.
53+
// PD scheduler uses three base schedulers to process requests, the overall configuration is currently loaded from environment variables.
54+
// If the request prompt is short enough (defined by the threshold in the configuration) - use the default behavior
55+
// If the request prompt is long enough to use prefill-decode process:
56+
// 1 - find the pod for prefill, save its url in a special header. For this, use the Scheduler configured for this goal, which uses the prefill filter
57+
// and scorers according to the configuration.
58+
// 2 - find the pod for decode, use the Scheduler configured for this goal, which uses the decode filer and scorers defined in the configuration
59+
func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) {
60+
logger := log.FromContext(ctx).WithValues("pd-schedule", req)
61+
62+
if len(req.Prompt) < promptLengthThreshold {
63+
// the prompt is short enough - use the default scheduling logic
64+
return s.defaultScheduler.Schedule(ctx, req)
65+
}
66+
67+
sCtx, err := createSchedulerContext(ctx, req, s.datastore)
68+
if err != nil {
69+
return nil, err
70+
}
71+
72+
// prompt requires processing on two pods - prefill and decode
73+
// start with calculating of the prefill pod
74+
res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger)
75+
if err != nil {
76+
return nil, err
77+
}
78+
79+
if res.TargetPod != nil {
80+
url := fmt.Sprintf("http://%s:%d", res.TargetPod.GetPod().Address, sCtx.TargetPort)
81+
sCtx.MutatedHeaders[prefillPodHeader] = url
82+
}
83+
84+
// get decode pod
85+
return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger)
86+
}
87+
88+
func (s *PDScheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) {
89+
return s.decodeScheduler.RunPostResponsePlugins(ctx, req, targetPodName)
90+
}

0 commit comments

Comments
 (0)