Skip to content

Add P/D scheduler #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,56 @@ This project offers tools for AI Inference, enabling developers to build [Infere
---
## Temporary Fork Configuration

To enable KVCacheAwareScorer, the following env vars must be configured:
To enable the KVCacheAwareScorer, the following environment variables must be configured:
```
export ENABLE_KVCACHE_AWARE_SCORER=true
export KVCACHE_AWARE_SCORER_WEIGHT=1.0
export KVCACHE_INDEXER_REDIS_ADDR=<redis-service>
export HF_TOKEN=<HuggingFace Token that has access to the vLLM models>
```

To enable LoadAwareScorer, the following env vars must be configured:
To enable the LoadAwareScorer, the following environment variables must be configured:
```
export ENABLE_LOAD_AWARE_SCORER=true
export LOAD_AWARE_SCORER_WEIGHT=1.0
```

To enable PDFilter, the following env var must be configured:
To enable Prefill/Decode (PD) processing, the following environment variable must be configured:
```
export ENABLE_PD_FILTER=true
export PD_ENABLED=true
```

To define the prompt length threshold (requests with a prompt longer than the value defined here will be processed using the prefill-decode process), the following environment variable must be configured:
```
export PD_PROMPT_LEN_THRESHOLD=10
```

Prefill configuration:

To enable and configure the kv cache scorer for prefill, the following environment variables must be configured:
```
export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true
export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0
```

To enable and configure the load aware scorer for prefill, the following environment variables must be configured:
```
export PREFILL_ENABLE_LOAD_AWARE_SCORER=true
export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0
```

Decode configuration:

To enable and configure the kv cache scorer for decode, the following environment variables must be configured:
```
export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true
export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0
```

To enable and configure the load aware scorer for decode, the following environment variables must be configured:
```
export DECODE_ENABLE_LOAD_AWARE_SCORER=true
export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0
```
---
[Inference Gateways]:#concepts-and-definitions
Expand Down
84 changes: 84 additions & 0 deletions pkg/epp/scheduling/config_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
"context"
"fmt"

"github.com/go-logr/logr"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
)

const (
prefillKvCacheScorerEnablementEnvVar = "PREFILL_ENABLE_KVCACHE_AWARE_SCORER"
prefillLoadAwareScorerEnablementEnvVar = "PREFILL_ENABLE_LOAD_AWARE_SCORER"
decodeKvCacheScorerEnablementEnvVar = "DECODE_ENABLE_KVCACHE_AWARE_SCORER"
decodeLoadAwareScorerEnablementEnvVar = "DECODE_ENABLE_LOAD_AWARE_SCORER"

prefillKvCacheScorerWeightEnvVar = "PREFILL_KVCACHE_AWARE_SCORER_WEIGHT"
prefillLoadAwareScorerWeightEnvVar = "PREFILL_LOAD_AWARE_SCORER_WEIGHT"
decodeKvCacheScorerWeightEnvVar = "DECODE_KVCACHE_AWARE_SCORER_WEIGHT"
decodeLoadAwareScorerWeightEnvVar = "DECODE_LOAD_AWARE_SCORER_WEIGHT"

pdEnabledEnvKey = "PD_ENABLED"

pdPromptLenThresholdEnvKey = "PD_PROMPT_LEN_THRESHOLD"
pdPromptLenThresholdDefault = 10
)

const (
loadAwareScorerName = "LoadAwareScorer"
kvCacheAwareScorerName = "KVCacheAwareScorer"
)

func addScorerByEnvironment(ctx context.Context, config *SchedulerConfig, scorerName string, scorerEnabledEnvKey string, weightEnvKey string, logger logr.Logger) {
if envutil.GetEnvString(scorerEnabledEnvKey, "false", logger) != "true" {
logger.Info(fmt.Sprintf("Skipping %s creation as it is not enabled", scorerName))
return
}

weight := envutil.GetEnvInt(weightEnvKey, 1, logger)
scorer, err := createScorerByName(ctx, scorerName)
if err != nil {
logger.Error(err, "Failed to create scorrer")
return
}

defaultConfig.scorers[scorer] = weight
logger.Info("Initialized scorer", "scorer", scorerName, "weight", weight)
}

func createScorerByName(ctx context.Context, name string) (plugins.Scorer, error) {
switch name {
case loadAwareScorerName:
return &scorer.LoadAwareScorer{}, nil
case kvCacheAwareScorerName:
return scorer.NewKVCacheAwareScorer(ctx)
}
return nil, fmt.Errorf("invalid scorer type %s", name)
}

func getPDEnabledFromEnvironment(logger logr.Logger) bool {
return envutil.GetEnvString(pdEnabledEnvKey, "false", logger) == "true"
}

func getPDPromptLenThresholdFromEnvironment(logger logr.Logger) int {
return envutil.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger)
}
15 changes: 0 additions & 15 deletions pkg/epp/scheduling/local_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
Expand All @@ -45,7 +44,6 @@ func setDefaultConfig() {
// this configuration is a temporary state, it should be better streamlined.
setLoadAwareScorer()
setKVCacheAwareScorer()
setPDFilter()

defaultConfig.picker = picker.NewMaxScorePicker()
}
Expand Down Expand Up @@ -83,16 +81,3 @@ func setKVCacheAwareScorer() {
defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
}

func setPDFilter() {
ctx := context.Background()
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)

if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" {
loggerDebug.Info("Skipping PDFilter creation as it is not enabled")
return
}

defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter)
loggerDebug.Info("Initialized PDFilter")
}
72 changes: 72 additions & 0 deletions pkg/epp/scheduling/pd_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
"context"

"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

var prefillConfig = &SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{},
filters: []plugins.Filter{filter.PrefillFilter},
scorers: map[plugins.Scorer]int{},
picker: picker.NewMaxScorePicker(),
postSchedulePlugins: []plugins.PostSchedule{},
postResponsePlugins: []plugins.PostResponse{},
}
var decodeConfig = &SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{},
filters: []plugins.Filter{filter.DecodeFilter},
scorers: map[plugins.Scorer]int{},
picker: picker.NewMaxScorePicker(),
postSchedulePlugins: []plugins.PostSchedule{},
postResponsePlugins: []plugins.PostResponse{},
}

var PDEnabled = false
var promptLengthThreshold int

func init() {
ctx := context.Background()
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)

loadPrefillConfiguration(ctx, loggerDebug)
loadDecodeConfiguration(ctx, loggerDebug)

// set IsPDEnabled by environment
PDEnabled = getPDEnabledFromEnvironment(loggerDebug)
promptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug)
}

func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) {
// add scorers
addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)
}

func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) {
// add scorers
addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)
}
90 changes: 90 additions & 0 deletions pkg/epp/scheduling/pd_scheduler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package scheduling implements request scheduling algorithms.
package scheduling

import (
"context"
"fmt"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

const (
prefillPodHeader = "x-prefiller-url"
)

func NewPDScheduler(datastore Datastore) *PDScheduler {
return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig, defaultConfig)
}

func NewPDSchedulerWithConfig(datastore Datastore, pConfig *SchedulerConfig, dConfig *SchedulerConfig, defConfig *SchedulerConfig) *PDScheduler {
return &PDScheduler{
datastore: datastore,
prefillScheduler: NewSchedulerWithConfig(datastore, pConfig),
decodeScheduler: NewSchedulerWithConfig(datastore, dConfig),
defaultScheduler: NewSchedulerWithConfig(datastore, defConfig),
}
}

type PDScheduler struct {
datastore Datastore
prefillScheduler *Scheduler
decodeScheduler *Scheduler
defaultScheduler *Scheduler
}

// Schedule finds the target pod based on metrics and the requested lora adapter.
// PD scheduler uses three base schedulers to process requests, the overall configuration is currently loaded from environment variables.
// If the request prompt is short enough (defined by the threshold in the configuration) - use the default behavior
// If the request prompt is long enough to use prefill-decode process:
// 1 - find the pod for prefill, save its url in a special header. For this, use the Scheduler configured for this goal, which uses the prefill filter
// and scorers according to the configuration.
// 2 - find the pod for decode, use the Scheduler configured for this goal, which uses the decode filer and scorers defined in the configuration
func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) {
logger := log.FromContext(ctx).WithValues("pd-schedule", req)

if len(req.Prompt) < promptLengthThreshold {
// the prompt is short enough - use the default scheduling logic
return s.defaultScheduler.Schedule(ctx, req)
}

sCtx, err := createSchedulerContext(ctx, req, s.datastore)
if err != nil {
return nil, err
}

// prompt requires processing on two pods - prefill and decode
// start with calculating of the prefill pod
res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger)
if err != nil {
return nil, err
}

if res.TargetPod != nil {
url := fmt.Sprintf("http://%s:%d", res.TargetPod.GetPod().Address, sCtx.TargetPort)
sCtx.MutatedHeaders[prefillPodHeader] = url
}

// get decode pod
return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger)
}

func (s *PDScheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) {
return s.decodeScheduler.RunPostResponsePlugins(ctx, req, targetPodName)
}
Loading