Skip to content

Add P/D scheduler #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,42 @@ export ENABLE_LOAD_AWARE_SCORER=true
export LOAD_AWARE_SCORER_WEIGHT=1.0
```

To enable PDFilter, the following env var must be configured:
To enable PD Scheduler, the following env var must be configured:
```
export ENABLE_PD_FILTER=true
export PD_ENABLED=true
```

To define prompt length threshold (requests with length is longer than the value defined here will be processed using prefill-decode process), the following env var must be configured:
```
export PD_PROMPT_LEN_THRESHOLD=10
```

Prefill scheduler configuration:

To enable and configure kv cache scorer, the following env vars must be configured:
```
export PREFILL_ENABLE_KVCACHE_AWARE_SCORER=true
export PREFILL_KVCACHE_AWARE_SCORER_WEIGHT=1.0
```

To enable and configure load aware scorer, the following env vars must be configured:
```
export PREFILL_ENABLE_LOAD_AWARE_SCORER=true
export PREFILL_LOAD_AWARE_SCORER_WEIGHT=1.0
```

Decode scheduler configuration:

To enable and configure kv cache scorer, the following env vars must be configured:
```
export DECODE_ENABLE_KVCACHE_AWARE_SCORER=true
export DECODE_KVCACHE_AWARE_SCORER_WEIGHT=1.0
```

To enable and configure load aware scorer, the following env vars must be configured:
```
export DECODE_ENABLE_LOAD_AWARE_SCORER=true
export DECODE_LOAD_AWARE_SCORER_WEIGHT=1.0
```
---
[Inference Gateways]:#concepts-and-definitions
Expand Down
84 changes: 84 additions & 0 deletions pkg/epp/scheduling/config_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
"context"
"fmt"

"github.com/go-logr/logr"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
)

const (
prefillKvCacheScorerEnablementEnvVar = "PREFILL_ENABLE_KVCACHE_AWARE_SCORER"
prefillLoadAwareScorerEnablementEnvVar = "PREFILL_ENABLE_LOAD_AWARE_SCORER"
decodeKvCacheScorerEnablementEnvVar = "DECODE_ENABLE_KVCACHE_AWARE_SCORER"
decodeLoadAwareScorerEnablementEnvVar = "DECODE_ENABLE_LOAD_AWARE_SCORER"

prefillKvCacheScorerWeightEnvVar = "PREFILL_KVCACHE_AWARE_SCORER_WEIGHT"
prefillLoadAwareScorerWeightEnvVar = "PREFILL_LOAD_AWARE_SCORER_WEIGHT"
decodeKvCacheScorerWeightEnvVar = "DECODE_KVCACHE_AWARE_SCORER_WEIGHT"
decodeLoadAwareScorerWeightEnvVar = "DECODE_LOAD_AWARE_SCORER_WEIGHT"

pdEnabledEnvKey = "PD_ENABLED"

pdPromptLenThresholdEnvKey = "PD_PROMPT_LEN_THRESHOLD"
pdPromptLenThresholdDefault = 10
)

const (
loadAwareScorerName = "LoadAwareScorer"
kvCacheAwareScorerName = "KVCacheAwareScorer"
)

func addScorerByEnvironment(ctx context.Context, config *SchedulerConfig, scorerName string, scorerEnabledEnvKey string, weightEnvKey string, logger logr.Logger) {
if envutil.GetEnvString(scorerEnabledEnvKey, "false", logger) != "true" {
logger.Info(fmt.Sprintf("Skipping %s creation as it is not enabled", scorerName))
return
}

weight := envutil.GetEnvInt(weightEnvKey, 1, logger)
scorer, err := createScorerByName(ctx, scorerName)
if err != nil {
logger.Error(err, "Failed to create scorrer")
return
}

defaultConfig.scorers[scorer] = weight
logger.Info("Initialized scorer", "scorer", scorerName, "weight", weight)
}

func createScorerByName(ctx context.Context, name string) (plugins.Scorer, error) {
switch name {
case loadAwareScorerName:
return &scorer.LoadAwareScorer{}, nil
case kvCacheAwareScorerName:
return scorer.NewKVCacheAwareScorer(ctx)
}
return nil, fmt.Errorf("invalid scorer type %s", name)
}

func getPDEnabledFromEnvironment(logger logr.Logger) bool {
return envutil.GetEnvString(pdEnabledEnvKey, "false", logger) == "true"
}

func getPDPromptLenThresholdFromEnvironment(logger logr.Logger) int {
return envutil.GetEnvInt(pdPromptLenThresholdEnvKey, pdPromptLenThresholdDefault, logger)
}
15 changes: 0 additions & 15 deletions pkg/epp/scheduling/local_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/scorer"
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
Expand All @@ -45,7 +44,6 @@ func setDefaultConfig() {
// this configuration is a temporary state, it should be better streamlined.
setLoadAwareScorer()
setKVCacheAwareScorer()
setPDFilter()

defaultConfig.picker = picker.NewMaxScorePicker()
}
Expand Down Expand Up @@ -83,16 +81,3 @@ func setKVCacheAwareScorer() {
defaultConfig.scorers[kvCacheScorer] = kvCacheScorerWeight
loggerDebug.Info("Initialized KVCacheAwareScorer", "weight", kvCacheScorerWeight)
}

func setPDFilter() {
ctx := context.Background()
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)

if envutil.GetEnvString(pdFilterEnablementEnvVar, "false", loggerDebug) != "true" {
loggerDebug.Info("Skipping PDFilter creation as it is not enabled")
return
}

defaultConfig.filters = append(defaultConfig.filters, filter.PDFilter)
loggerDebug.Info("Initialized PDFilter")
}
78 changes: 78 additions & 0 deletions pkg/epp/scheduling/pd_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
"context"

"github.com/go-logr/logr"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

var prefillConfig = &SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{},
filters: []plugins.Filter{filter.PrefillFilter},
scorers: map[plugins.Scorer]int{},
picker: picker.NewMaxScorePicker(),
postSchedulePlugins: []plugins.PostSchedule{},
}
var decodeConfig = &SchedulerConfig{
preSchedulePlugins: []plugins.PreSchedule{},
filters: []plugins.Filter{filter.DecodeFilter},
scorers: map[plugins.Scorer]int{},
picker: picker.NewMaxScorePicker(),
postSchedulePlugins: []plugins.PostSchedule{},
}

var IsPDEnabled = false
var PromptLengthThreshold int

func init() {
ctx := context.Background()
loggerDebug := log.FromContext(ctx).WithName("scheduler_config").V(logutil.DEBUG)

loadPrefillConfiguration(ctx, loggerDebug)
loadDecodeConfiguration(ctx, loggerDebug)

// set IsPDEnabled by environment
IsPDEnabled = getPDEnabledFromEnvironment(loggerDebug)
PromptLengthThreshold = getPDPromptLenThresholdFromEnvironment(loggerDebug)
}

func loadPrefillConfiguration(ctx context.Context, logger logr.Logger) {
// add scorers
addScorerByEnvironment(ctx, prefillConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
addScorerByEnvironment(ctx, prefillConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)

// set filter
// TODO - do we want to keep default filters?
prefillConfig.filters = []plugins.Filter{filter.PrefillFilter}
}

func loadDecodeConfiguration(ctx context.Context, logger logr.Logger) {
// add scorers
addScorerByEnvironment(ctx, decodeConfig, kvCacheAwareScorerName, kvCacheScorerEnablementEnvVar, kvCacheScorerWeightEnvVar, logger)
addScorerByEnvironment(ctx, decodeConfig, loadAwareScorerName, loadAwareScorerEnablementEnvVar, loadAwareScorerWeightEnvVar, logger)

// set filter
// TODO - do we want to keep default filters?
decodeConfig.filters = []plugins.Filter{filter.DecodeFilter}
}
88 changes: 88 additions & 0 deletions pkg/epp/scheduling/pd_scheduler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package scheduling implements request scheduling algorithms.
package scheduling

import (
"context"
"fmt"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)

const (
prefillPodHeader = "x-prefiller-url"
)

func NewPDScheduler(datastore Datastore) *PDScheduler {
return NewPDSchedulerWithConfig(datastore, prefillConfig, decodeConfig)
}

func NewPDSchedulerWithConfig(datastore Datastore, prefillConfig *SchedulerConfig, decodeConfig *SchedulerConfig) *PDScheduler {
return &PDScheduler{
datastore: datastore,
prefillScheduler: NewSchedulerWithConfig(datastore, prefillConfig),
decodeScheduler: NewSchedulerWithConfig(datastore, decodeConfig),
}
}

type PDScheduler struct {
datastore Datastore
prefillScheduler *Scheduler
decodeScheduler *Scheduler
}

// Schedule finds the target pod based on metrics and the requested lora adapter.
func (s *PDScheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) {
logger := log.FromContext(ctx).WithValues("pd-schedule", req)

if len(req.Prompt) < PromptLengthThreshold {
// prompt is short enough - use decode scheduling logic
return s.decodeScheduler.Schedule(ctx, req)
}

pool, err := s.datastore.PoolGet()
if err != nil {
return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods
}

// Snapshot pod metrics from the datastore to:
// 1. Reduce concurrent access to the datastore.
// 2. Ensure consistent data during the scheduling operation of a request.
sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()), pool.Spec.TargetPortNumber)

// prompt requires processing on two pods - prefill and decode
// start with calculating of the prefill pod
res, err := s.prefillScheduler.scheduleWithContext(ctx, sCtx, req, logger)
if err != nil {
return nil, err
}

if res.TargetPod != nil {
url := fmt.Sprintf("http://%s:%d", res.TargetPod.GetPod().Address, sCtx.TargetPort)
sCtx.MutatedHeaders[prefillPodHeader] = url
}

// get decode pod
return s.decodeScheduler.scheduleWithContext(ctx, sCtx, req, logger)
}

func (s *PDScheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) {
return s.decodeScheduler.RunPostResponsePlugins(ctx, req, targetPodName)
}
Loading