
added scheduler plugins interfaces to initialize scheduler refactoring #721

Closed
11 changes: 5 additions & 6 deletions pkg/epp/backend/metrics/pod_metrics.go
@@ -41,9 +41,8 @@ type podMetrics struct {
ds Datastore
interval time.Duration

parentCtx context.Context
once sync.Once // ensure the StartRefreshLoop is only called once.
done chan struct{}
once sync.Once // ensure the StartRefreshLoop is only called once.
done chan struct{}

logger logr.Logger
}
@@ -79,8 +78,8 @@ func toInternalPod(in *corev1.Pod) *Pod {
}

// start starts a goroutine exactly once to periodically update metrics. The goroutine will be
// stopped either when stop() is called, or the parentCtx is cancelled.
func (pm *podMetrics) startRefreshLoop() {
// stopped either when stop() is called, or the given ctx is cancelled.
func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
pm.once.Do(func() {
go func() {
pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod())
@@ -90,7 +89,7 @@ func (pm *podMetrics) startRefreshLoop() {
select {
case <-pm.done:
return
case <-pm.parentCtx.Done():
case <-ctx.Done():
return
case <-ticker.C: // refresh metrics periodically
if err := pm.refreshMetrics(); err != nil {
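
For readers following the refactor, here is a minimal standalone sketch of the start-once refresh-loop pattern this change moves to: sync.Once guards the goroutine, a done channel supports an explicit stop(), and the caller-supplied ctx handles cancellation. The refresher type and function names below are illustrative, not identifiers from this repository.

// Sketch only: the start-once refresh loop pattern from pod_metrics.go,
// with sync.Once, an explicit done channel, and a caller-supplied context.
// The refresher type and names are illustrative, not the repo's API.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type refresher struct {
	once sync.Once
	done chan struct{}
}

func newRefresher() *refresher {
	return &refresher{done: make(chan struct{})}
}

// start launches the loop at most once; ctx is passed in rather than stored on the struct.
func (r *refresher) start(ctx context.Context, interval time.Duration, refresh func() error) {
	r.once.Do(func() {
		go func() {
			ticker := time.NewTicker(interval)
			defer ticker.Stop()
			for {
				select {
				case <-r.done: // explicit stop()
					return
				case <-ctx.Done(): // caller cancelled
					return
				case <-ticker.C: // refresh periodically
					if err := refresh(); err != nil {
						fmt.Println("refresh failed:", err)
					}
				}
			}
		}()
	})
}

func (r *refresher) stop() { close(r.done) }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	r := newRefresher()
	r.start(ctx, 50*time.Millisecond, func() error { fmt.Println("tick"); return nil })
	time.Sleep(200 * time.Millisecond)
	r.stop()
}
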
23 changes: 15 additions & 8 deletions pkg/epp/backend/metrics/types.go
@@ -43,18 +43,17 @@ type PodMetricsFactory struct {
func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics {
pod := toInternalPod(in)
pm := &podMetrics{
pmc: f.pmc,
ds: ds,
interval: f.refreshMetricsInterval,
parentCtx: parentCtx,
once: sync.Once{},
done: make(chan struct{}),
logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
pmc: f.pmc,
ds: ds,
interval: f.refreshMetricsInterval,
once: sync.Once{},
done: make(chan struct{}),
logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
}
pm.pod.Store(pod)
pm.metrics.Store(newMetrics())

pm.startRefreshLoop()
pm.startRefreshLoop(parentCtx)
return pm
}

@@ -79,6 +78,10 @@ func (p *Pod) String() string {
}

func (p *Pod) Clone() *Pod {
if p == nil {
return nil
}

return &Pod{
NamespacedName: types.NamespacedName{
Name: p.NamespacedName.Name,
@@ -118,6 +121,10 @@ func (m *Metrics) String() string {
}

func (m *Metrics) Clone() *Metrics {
if m == nil {
return nil
}

cm := make(map[string]int, len(m.ActiveModels))
for k, v := range m.ActiveModels {
cm[k] = v
9 changes: 4 additions & 5 deletions pkg/epp/handlers/request.go
@@ -67,7 +67,7 @@ func (s *StreamingServer) HandleRequestBody(
ResolvedTargetModel: modelName,
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
}
logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)
logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)

var err error
// Update target models in the body.
@@ -81,11 +81,11 @@
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
}

target, err := s.scheduler.Schedule(ctx, llmReq)
res, err := s.scheduler.Schedule(ctx, llmReq)
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}
targetPod := target.GetPod()
targetPod := res.TargetPod.GetPod()

// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
// Attach the port number
@@ -96,8 +96,7 @@
endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))

logger.V(logutil.DEFAULT).Info("Request handled",
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics",
fmt.Sprintf("%+v", target))
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod)

reqCtx.Model = llmReq.Model
reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
2 changes: 1 addition & 1 deletion pkg/epp/handlers/server.go
@@ -65,7 +65,7 @@ type StreamingServer struct {
}

type Scheduler interface {
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error)
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
}

// RequestContext stores context information during the life time of an HTTP request.
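
A hedged sketch of what the new Schedule contract looks like from a caller's side: the handler unwraps the target pod from the returned result instead of receiving a pod directly. The Result shape here (a TargetPod field exposing GetPod()) is inferred from the request.go diff above; the real type in pkg/epp/scheduling/types may carry more fields.

// Sketch only: the Schedule signature change from (Pod, error) to (*Result, error),
// using stand-in types; the real ones live in pkg/epp/scheduling/types.
package main

import (
	"context"
	"fmt"
)

type Pod struct{ Address string }

// TargetPod is assumed to expose GetPod(), as used in request.go above.
type TargetPod interface{ GetPod() *Pod }

type fixedTarget struct{ pod *Pod }

func (f fixedTarget) GetPod() *Pod { return f.pod }

// Result wraps the chosen endpoint; the real Result may carry more fields.
type Result struct{ TargetPod TargetPod }

type LLMRequest struct{ Model string }

type Scheduler interface {
	Schedule(ctx context.Context, req *LLMRequest) (*Result, error)
}

type stubScheduler struct{}

func (stubScheduler) Schedule(ctx context.Context, req *LLMRequest) (*Result, error) {
	return &Result{TargetPod: fixedTarget{pod: &Pod{Address: "10.0.0.1"}}}, nil
}

func main() {
	var s Scheduler = stubScheduler{}
	res, err := s.Schedule(context.Background(), &LLMRequest{Model: "example-model"})
	if err != nil {
		panic(err)
	}
	// Callers now unwrap the target from the result instead of receiving a pod directly.
	fmt.Println("target:", res.TargetPod.GetPod().Address)
}
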
103 changes: 48 additions & 55 deletions pkg/epp/scheduling/filter.go
@@ -17,20 +17,15 @@ limitations under the License.
package scheduling

import (
"errors"
"math"
"math/rand"
"time"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

type Filter interface {
Name() string
Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error)
}

type basicFilter struct {
name string
filter filterFunc
@@ -43,7 +38,7 @@ func (bf *basicFilter) Name() string {
return bf.name
}

func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
func (bf *basicFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
loggerTrace := ctx.Logger.V(logutil.TRACE)
loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods))

@@ -54,19 +49,19 @@ func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*
// depending success or failure of the current filter.
// It can be used to construct a flow chart algorithm.
type decisionTreeFilter struct {
current Filter
current plugins.Filter
// nextOnSuccess filter will be applied after successfully applying the current filter.
// The filtered results will be passed to the next filter.
nextOnSuccess Filter
nextOnSuccess plugins.Filter
// nextOnFailure filter will be applied if current filter fails.
// The original input will be passed to the next filter.
nextOnFailure Filter
nextOnFailure plugins.Filter
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
// success or failure of the current filter.
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
nextOnSuccessOrFailure Filter
nextOnSuccessOrFailure plugins.Filter
}

func (f *decisionTreeFilter) Name() string {
@@ -76,15 +71,15 @@ func (f *decisionTreeFilter) Name() string {
return f.current.Name()
}

func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
func (f *decisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
loggerTrace := ctx.Logger.V(logutil.TRACE)
filtered, err := f.current.Filter(ctx, pods)
filtered := f.current.Filter(ctx, pods)

next := f.nextOnSuccessOrFailure
if err == nil && len(filtered) > 0 {
if len(filtered) > 0 {
if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
// No succeeding filters to run, return.
return filtered, err
return filtered
}
if f.nextOnSuccess != nil {
next = f.nextOnSuccess
@@ -95,7 +90,7 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics
} else {
if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
// No succeeding filters to run, return.
return filtered, err
return filtered
}
if f.nextOnFailure != nil {
next = f.nextOnFailure
Expand All @@ -107,22 +102,20 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics
}
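
A minimal sketch, over simplified types, of how a decision tree like the one above is wired: run current, hand the filtered set to the success branch when it is non-empty, and hand the original input to the failure branch otherwise. The concrete filters below are placeholders, not the tree this repository actually builds.

// Sketch only: a two-level decision tree over a simplified Filter interface.
// Types here are stand-ins for types.SchedulingContext / types.Pod.
package main

import "fmt"

type Pod struct{ Queue int }

type Filter interface {
	Name() string
	Filter(pods []*Pod) []*Pod
}

type funcFilter struct {
	name string
	fn   func([]*Pod) []*Pod
}

func (f funcFilter) Name() string           { return f.name }
func (f funcFilter) Filter(p []*Pod) []*Pod { return f.fn(p) }

// decisionTree mirrors the structure in the diff: run current, then branch
// on whether the filtered set is empty.
type decisionTree struct {
	current, onSuccess, onFailure Filter
}

func (d decisionTree) Name() string { return d.current.Name() }

func (d decisionTree) Filter(pods []*Pod) []*Pod {
	filtered := d.current.Filter(pods)
	if len(filtered) > 0 {
		if d.onSuccess == nil {
			return filtered
		}
		return d.onSuccess.Filter(filtered) // success branch sees the filtered set
	}
	if d.onFailure == nil {
		return filtered
	}
	return d.onFailure.Filter(pods) // failure branch sees the original input
}

func main() {
	lowQueue := funcFilter{name: "low-queue", fn: func(pods []*Pod) []*Pod {
		out := []*Pod{}
		for _, p := range pods {
			if p.Queue <= 2 {
				out = append(out, p)
			}
		}
		return out
	}}
	leastQueue := funcFilter{name: "least-queue", fn: func(pods []*Pod) []*Pod {
		if len(pods) == 0 {
			return pods
		}
		best := pods[0]
		for _, p := range pods[1:] {
			if p.Queue < best.Queue {
				best = p
			}
		}
		return []*Pod{best}
	}}

	tree := decisionTree{current: lowQueue, onSuccess: leastQueue, onFailure: leastQueue}
	fmt.Println(tree.Filter([]*Pod{{Queue: 1}, {Queue: 5}, {Queue: 0}})[0].Queue) // 0
}
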

// filterFunc filters a set of input pods to a subset.
type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error)
type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod

// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
func toFilterFunc(pp podPredicate) filterFunc {
return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
filtered := []*types.PodMetrics{}
return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
filtered := []types.Pod{}
for _, pod := range pods {
pass := pp(ctx.Req, pod)
if pass {
filtered = append(filtered, pod)
}
}
if len(filtered) == 0 {
return nil, errors.New("no pods left")
}
return filtered, nil

return filtered
}
}

Expand All @@ -138,26 +131,26 @@ var leastQueueFilter = &basicFilter{
// the least one as it gives more choices for the next filter, which on aggregate gave better
// results.
// TODO: Compare this strategy with other strategies such as top K.
func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
min := math.MaxInt
max := 0
filtered := []*types.PodMetrics{}
filtered := []types.Pod{}

for _, pod := range pods {
if pod.WaitingQueueSize <= min {
min = pod.WaitingQueueSize
if pod.GetMetrics().WaitingQueueSize <= min {
min = pod.GetMetrics().WaitingQueueSize
}
if pod.WaitingQueueSize >= max {
max = pod.WaitingQueueSize
if pod.GetMetrics().WaitingQueueSize >= max {
max = pod.GetMetrics().WaitingQueueSize
}
}

for _, pod := range pods {
if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) {
filtered = append(filtered, pod)
}
}
return filtered, nil
return filtered
}
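
To make the band rule concrete: with queue sizes 2, 3, 9, 10, min is 2, max is 10, and the accepted band is [2, 2+(10-2)/4] = [2, 4], so only the pods with queues 2 and 3 pass. A small standalone sketch of just that computation:

// Sketch only: the band rule from leastQueuingFilterFunc, keeping pods whose
// waiting queue falls within [min, min+(max-min)/len(pods)].
package main

import (
	"fmt"
	"math"
)

func leastQueuingBand(queues []int) []int {
	minQ, maxQ := math.MaxInt, 0
	for _, q := range queues {
		if q < minQ {
			minQ = q
		}
		if q > maxQ {
			maxQ = q
		}
	}
	kept := []int{}
	for _, q := range queues {
		if q >= minQ && q <= minQ+(maxQ-minQ)/len(queues) {
			kept = append(kept, q)
		}
	}
	return kept
}

func main() {
	// min=2, max=10, band = [2, 2+(10-2)/4] = [2, 4], so 2 and 3 are kept.
	fmt.Println(leastQueuingBand([]int{2, 3, 9, 10}))
}
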

var lowQueueFilter = &basicFilter{
@@ -176,26 +169,26 @@ var leastKVCacheFilter = &basicFilter{
// should consider them all instead of the absolute minimum one. This worked better than picking the
// least one as it gives more choices for the next filter, which on aggregate gave better results.
// TODO: Compare this strategy with other strategies such as top K.
func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
min := math.MaxFloat64
var max float64 = 0
filtered := []*types.PodMetrics{}
filtered := []types.Pod{}

for _, pod := range pods {
if pod.KVCacheUsagePercent <= min {
min = pod.KVCacheUsagePercent
if pod.GetMetrics().KVCacheUsagePercent <= min {
min = pod.GetMetrics().KVCacheUsagePercent
}
if pod.KVCacheUsagePercent >= max {
max = pod.KVCacheUsagePercent
if pod.GetMetrics().KVCacheUsagePercent >= max {
max = pod.GetMetrics().KVCacheUsagePercent
}
}

for _, pod := range pods {
if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
return filtered
}

var loRAAffinityFilter = &basicFilter{
@@ -219,20 +212,20 @@ var loRAAffinityFilter = &basicFilter{
// Returns:
// - Filtered slice of pod metrics based on affinity and availability
// - Error if any issues occur during filtering
func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {

// Pre-allocate slices with estimated capacity
filtered_affinity := make([]*types.PodMetrics, 0, len(pods))
filtered_available := make([]*types.PodMetrics, 0, len(pods))
filtered_affinity := make([]types.Pod, 0, len(pods))
filtered_available := make([]types.Pod, 0, len(pods))

// Categorize pods based on affinity and availability
for _, pod := range pods {
_, active := pod.ActiveModels[ctx.Req.ResolvedTargetModel]
_, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel]
_, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel]
_, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel]

if active || waiting {
filtered_affinity = append(filtered_affinity, pod)
} else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels {
} else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels {
filtered_available = append(filtered_available, pod)
}
}
@@ -244,36 +237,36 @@ func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([
// If both groups have pods, use probability to select which group to return
if len(filtered_affinity) > 0 && len(filtered_available) > 0 {
if randGen.Float64() < config.LoraAffinityThreshold {
return filtered_affinity, nil
return filtered_affinity
}
return filtered_available, nil
return filtered_available
}

// Return whichever group has pods
if len(filtered_affinity) > 0 {
return filtered_affinity, nil
return filtered_affinity
}

return filtered_available, nil
Review comment, @kfswain (Collaborator), Apr 21, 2025:
Oh... we never returned an error? I wonder if there is a Go linter that checks for this kind of thing.

Reply from the PR author (Contributor):
Exactly: there is no use of the error, and also no use of DropRequestFilter.
The scheduler can just check whether len(returned_pods) == 0 (up until now the scheduler did both the error check and that length check).
That was one of my comments on the original PR:
https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/677/files#r2042685374
return filtered_available
}
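
The choice between the two groups above is probabilistic: with probability config.LoraAffinityThreshold the affinity group wins, otherwise the pods with spare adapter slots are returned. A hedged sketch of just that selection follows; the 0.999 threshold is an assumed example value, not necessarily this repo's configured default.

// Sketch only: the probabilistic choice between the affinity group and the
// available group in loRASoftAffinityFilterFunc. The 0.999 threshold is an
// assumed example value, not necessarily this repo's default.
package main

import (
	"fmt"
	"math/rand"
)

func pickGroup(affinity, available []string, threshold float64, r *rand.Rand) []string {
	switch {
	case len(affinity) > 0 && len(available) > 0:
		if r.Float64() < threshold {
			return affinity // usually prefer pods already serving (or loading) the adapter
		}
		return available // occasionally spread load to pods with free adapter slots
	case len(affinity) > 0:
		return affinity
	default:
		return available
	}
}

func main() {
	r := rand.New(rand.NewSource(1))
	fmt.Println(pickGroup([]string{"pod-a"}, []string{"pod-b"}, 0.999, r))
}
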

// podPredicate is a filter function to check whether a pod is desired.
type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool
type podPredicate func(req *types.LLMRequest, pod types.Pod) bool

func queueThresholdPredicate(queueThreshold int) podPredicate {
return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
return pod.WaitingQueueSize <= queueThreshold
return func(req *types.LLMRequest, pod types.Pod) bool {
return pod.GetMetrics().WaitingQueueSize <= queueThreshold
}
}

func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate {
return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
return pod.KVCacheUsagePercent <= kvCacheThreshold
return func(req *types.LLMRequest, pod types.Pod) bool {
return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold
}
}

func (pp podPredicate) and(another podPredicate) podPredicate {
return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
return func(req *types.LLMRequest, pod types.Pod) bool {
return pp(req, pod) && another(req, pod)
}
}
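
A hedged sketch of how the predicate helpers compose: two per-pod checks joined with and(), then lifted to a slice filter the way toFilterFunc does. Types and thresholds are simplified stand-ins, and, as discussed in the review thread above, an empty result is returned as-is for the scheduler to detect with a length check.

// Sketch only: composing podPredicate-style checks with and() and lifting them
// to a slice filter the way toFilterFunc does, using simplified stand-in types.
package main

import "fmt"

type Pod struct {
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

type predicate func(p *Pod) bool

func (pp predicate) and(other predicate) predicate {
	return func(p *Pod) bool { return pp(p) && other(p) }
}

func queueAtMost(n int) predicate {
	return func(p *Pod) bool { return p.WaitingQueueSize <= n }
}

func kvCacheAtMost(t float64) predicate {
	return func(p *Pod) bool { return p.KVCacheUsagePercent <= t }
}

// toFilter lifts a per-pod predicate to a filter over a slice of pods.
func toFilter(pp predicate) func([]*Pod) []*Pod {
	return func(pods []*Pod) []*Pod {
		out := []*Pod{}
		for _, p := range pods {
			if pp(p) {
				out = append(out, p)
			}
		}
		return out // may be empty; the scheduler can check len() == 0 itself
	}
}

func main() {
	lowLoad := toFilter(queueAtMost(128).and(kvCacheAtMost(0.8)))
	pods := []*Pod{
		{WaitingQueueSize: 10, KVCacheUsagePercent: 0.4},
		{WaitingQueueSize: 300, KVCacheUsagePercent: 0.2},
	}
	fmt.Println(len(lowLoad(pods))) // 1
}
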