Commit d8303a0

Merge pull request kubernetes-sigs#1 from mayabar/main

Add scorers support in scheduler

2 parents 09e79e6 + aca8e07

8 files changed, +296 -7 lines changed

pkg/epp/datastore/datastore.go (+69, -2)

@@ -21,6 +21,7 @@ import (
     "errors"
     "fmt"
     "sync"
+    "time"

     corev1 "k8s.io/api/core/v1"
     "k8s.io/apimachinery/pkg/labels"
@@ -34,7 +35,9 @@ import (
 )

 const (
-    ModelNameIndexKey = "spec.modelName"
+    ModelNameIndexKey              = "spec.modelName"
+    sessionKeepAliveTime           = 60 * time.Minute // how long an idle session is kept alive
+    sessionKeepAliveCheckFrequency = 15 * time.Minute // how often to check for overly idle sessions
 )

 var (
@@ -65,6 +68,9 @@ type Datastore interface {
     PodDelete(namespacedName types.NamespacedName)
     PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool)

+    SetPodForSession(sessionID string, pod *backendmetrics.Pod)
+    GetPodForSession(sessionID string) *backendmetrics.Pod
+
     // Clears the store state, happens when the pool gets deleted.
     Clear()
 }
@@ -75,8 +81,12 @@ func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFacto
         poolAndModelsMu: sync.RWMutex{},
         models:          make(map[string]*v1alpha2.InferenceModel),
         pods:            &sync.Map{},
+        sessions:        &sync.Map{},
         pmf:             pmf,
     }
+
+    go store.cleanupSessions(sessionKeepAliveCheckFrequency, sessionKeepAliveTime, parentCtx)
+
     return store
 }

@@ -90,7 +100,9 @@ type datastore struct {
     models map[string]*v1alpha2.InferenceModel
     // key: types.NamespacedName, value: backendmetrics.PodMetrics
     pods *sync.Map
-    pmf  *backendmetrics.PodMetricsFactory
+    // key: session id, value: *backendmetrics.Pod
+    sessions *sync.Map
+    pmf      *backendmetrics.PodMetricsFactory
 }

 func (ds *datastore) Clear() {
@@ -291,6 +303,61 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
     }
 }

+type sessionInfo struct {
+    pod *backendmetrics.Pod
+    lru time.Time
+}
+
+// cleanupSessions periodically removes stored session information for
+// sessions that have been idle longer than sessionKeepAlive.
+func (ds *datastore) cleanupSessions(keepAliveCheckFrequency time.Duration, sessionKeepAlive time.Duration, ctx context.Context) {
+    logger := log.FromContext(ctx)
+
+    logger.Info("Session-affinity cleanup started")
+    ticker := time.NewTicker(keepAliveCheckFrequency)
+    defer ticker.Stop()
+
+    for {
+        select {
+        case <-ctx.Done():
+            logger.Info("Session-affinity cleanup stopped")
+            return
+        case now := <-ticker.C:
+            logger.Info("Session affinity checking")
+            ds.sessions.Range(
+                func(sessionID any, rawSessionInfo any) bool {
+                    if sessionInfo, ok := rawSessionInfo.(*sessionInfo); ok {
+                        if now.Sub(sessionInfo.lru) > sessionKeepAlive {
+                            // Session is stale, remove it
+                            ds.sessions.Delete(sessionID)
+                        }
+                    } else {
+                        // Value is not of the correct type, remove it
+                        ds.sessions.Delete(sessionID)
+                    }
+                    return true
+                })
+        }
+    }
+}
+
+func (ds *datastore) SetPodForSession(sessionID string, pod *backendmetrics.Pod) {
+    ds.sessions.Store(sessionID, &sessionInfo{
+        pod: pod,
+        lru: time.Now(),
+    })
+}
+
+func (ds *datastore) GetPodForSession(sessionID string) *backendmetrics.Pod {
+    if value, ok := ds.sessions.Load(sessionID); ok {
+        if sessionInfo, ok := value.(*sessionInfo); ok {
+            return sessionInfo.pod
+        }
+    }
+
+    return nil
+}
+
 func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector {
     return labels.SelectorFromSet(stripLabelKeyAliasFromLabelMap(selector))
 }
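The datastore change boils down to a small pattern: a sync.Map keyed by session id whose entries carry a last-used timestamp, plus a janitor goroutine that periodically sweeps stale entries. A standalone sketch of that pattern, with string pods and short durations standing in for the project's types (all names here are illustrative, not the project's API):

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

type entry struct {
    pod      string    // stands in for *backendmetrics.Pod
    lastUsed time.Time // refreshed on every Set, like sessionInfo.lru
}

type sessionStore struct{ m sync.Map }

func (s *sessionStore) Set(id, pod string) {
    s.m.Store(id, &entry{pod: pod, lastUsed: time.Now()})
}

func (s *sessionStore) Get(id string) (string, bool) {
    if v, ok := s.m.Load(id); ok {
        return v.(*entry).pod, true
    }
    return "", false
}

// janitor mirrors cleanupSessions: tick, scan the map, drop stale entries.
func (s *sessionStore) janitor(ctx context.Context, every, ttl time.Duration) {
    t := time.NewTicker(every)
    defer t.Stop()
    for {
        select {
        case <-ctx.Done():
            return
        case now := <-t.C:
            s.m.Range(func(k, v any) bool {
                if now.Sub(v.(*entry).lastUsed) > ttl {
                    s.m.Delete(k)
                }
                return true
            })
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    store := &sessionStore{}
    go store.janitor(ctx, 50*time.Millisecond, 100*time.Millisecond)

    store.Set("abc", "pod-1")
    if pod, ok := store.Get("abc"); ok {
        fmt.Println("session abc ->", pod) // session abc -> pod-1
    }

    time.Sleep(250 * time.Millisecond) // let the entry go stale and be swept
    _, ok := store.Get("abc")
    fmt.Println("still present:", ok) // still present: false
}

One behavior worth noting in the real code as well: lru is only refreshed by SetPodForSession, not by GetPodForSession, so a session kept alive purely by scheduler reads still expires after sessionKeepAliveTime.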

pkg/epp/handlers/request.go (+15, -1)

@@ -21,9 +21,11 @@ import (
     "encoding/json"
     "fmt"
     "strconv"
+    "strings"
     "time"

     extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+    "github.com/google/uuid"
     "sigs.k8s.io/controller-runtime/pkg/log"
     "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
     schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -62,12 +64,14 @@ func (s *StreamingServer) HandleRequestBody(
             return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
         }
     }
+
     llmReq := &schedulingtypes.LLMRequest{
         Model:               model,
         ResolvedTargetModel: modelName,
         Critical:            modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
+        SessionID:           reqCtx.SessionID,
     }
-    logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)
+    logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical, "session id", reqCtx.SessionID)

     var err error
     // Update target models in the body.
@@ -132,6 +136,16 @@ func (s *StreamingServer) HandleRequestBody(
 func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
     reqCtx.RequestReceivedTimestamp = time.Now()

+    for _, header := range req.RequestHeaders.Headers.GetHeaders() {
+        value := string(header.RawValue)
+        if strings.EqualFold(header.Key, SessionIDHeader) && value != "" {
+            reqCtx.SessionID = value
+        }
+    }
+    if reqCtx.SessionID == "" {
+        reqCtx.SessionID = uuid.NewString()
+    }
+
     // an EoS in the request headers means this request has no body or trailers.
     if req.RequestHeaders.EndOfStream {
         // We will route this request to a random pod as this is assumed to just be a GET
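The header logic above reduces to: use the client-supplied Session-ID if present (HTTP header names are case-insensitive), otherwise mint a fresh UUID so the response can still advertise a session. A standalone sketch of that fallback, using net/http headers in place of the ext_proc types:

package main

import (
    "fmt"
    "net/http"

    "github.com/google/uuid"
)

const sessionIDHeader = "Session-ID"

// sessionID returns the client-supplied session id, or a new UUID if absent.
// http.Header.Get is case-insensitive, matching the EqualFold check above.
func sessionID(h http.Header) string {
    if v := h.Get(sessionIDHeader); v != "" {
        return v
    }
    return uuid.NewString()
}

func main() {
    h := http.Header{}
    h.Set("session-id", "abc-123") // lower-case on the wire still matches
    fmt.Println(sessionID(h))      // abc-123

    fmt.Println(sessionID(http.Header{})) // a freshly generated UUID
}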

pkg/epp/handlers/server.go (+19)

@@ -73,6 +73,7 @@ type RequestContext struct {
     TargetPod                 string
     TargetEndpoint            string
     Model                     string
+    SessionID                 string
     ResolvedTargetModel       string
     RequestReceivedTimestamp  time.Time
     ResponseCompleteTimestamp time.Time
@@ -108,6 +109,8 @@ const (
     TrailerResponseResponsesComplete StreamRequestState = 7
 )

+const SessionIDHeader = "Session-ID"
+
 func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
     ctx := srv.Context()
     logger := log.FromContext(ctx)
@@ -197,6 +200,16 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
                 loggerTrace.Info("model server is streaming response")
             }
         }
+        // Save the session id -> pod mapping
+        allPods := s.datastore.PodGetAll()
+
+        for _, pod := range allPods {
+            if pod.GetPod().NamespacedName.String() == reqCtx.TargetPod {
+                s.datastore.SetPodForSession(reqCtx.SessionID, pod.GetPod())
+                break
+            }
+        }
+
         reqCtx.RequestState = ResponseRecieved
         reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
             Response: &extProcPb.ProcessingResponse_ResponseHeaders{
@@ -211,6 +224,12 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
                             RawValue: []byte("true"),
                         },
                     },
+                    {
+                        Header: &configPb.HeaderValue{
+                            Key:      SessionIDHeader,
+                            RawValue: []byte(reqCtx.SessionID),
+                        },
+                    },
                 },
             },
         },
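Taken together with the request-side changes, the contract for clients is: omit Session-ID on the first request, read the generated id back from the response headers, and replay it on follow-ups to stay pinned to the same pod. A hypothetical client-side sketch (the gateway URL and JSON body are placeholders for your deployment):

package main

import (
    "fmt"
    "net/http"
    "strings"
)

func main() {
    const url = "http://my-gateway/v1/completions" // placeholder endpoint

    // First request: no Session-ID header; the EPP generates one and
    // echoes it back in the response headers.
    resp, err := http.Post(url, "application/json",
        strings.NewReader(`{"model":"my-model","prompt":"hi"}`))
    if err != nil {
        panic(err)
    }
    resp.Body.Close()
    session := resp.Header.Get("Session-ID")
    fmt.Println("assigned session:", session)

    // Follow-up requests replay the id to reach the same pod.
    req, _ := http.NewRequest(http.MethodPost, url,
        strings.NewReader(`{"model":"my-model","prompt":"again"}`))
    req.Header.Set("Session-ID", session)
    if resp2, err := http.DefaultClient.Do(req); err == nil {
        resp2.Body.Close()
    }
}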

pkg/epp/scheduling/scheduler.go (+13, -4)

@@ -20,7 +20,6 @@ package scheduling
 import (
     "context"
     "fmt"
-    "math/rand"

     "sigs.k8s.io/controller-runtime/pkg/log"
     backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -116,21 +115,27 @@
 )

 func NewScheduler(datastore Datastore) *Scheduler {
+    sMng := NewScorerMng()
+    sMng.addScorer(NewSessionAffinityScorer(1, datastore))
+
     return &Scheduler{
         datastore:              datastore,
         criticalRequestFilter:  lowLatencyFilter,
         sheddableRequestFilter: sheddableRequestFilter,
+        scorerMng:              sMng,
     }
 }

 type Scheduler struct {
     datastore              Datastore
     criticalRequestFilter  Filter
     sheddableRequestFilter Filter
+    scorerMng              *ScorerMng
 }

 type Datastore interface {
     PodGetAll() []backendmetrics.PodMetrics
+    GetPodForSession(sessionID string) *backendmetrics.Pod
 }

 // Schedule finds the target pod based on metrics and the requested lora adapter.
@@ -154,7 +159,11 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (target
     if err != nil || len(pods) == 0 {
         return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err)
     }
-    logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
-    i := rand.Intn(len(pods))
-    return pods[i], nil
+
+    selectedPod, err := s.scorerMng.scoreTargets(sCtx, pods)
+    if err != nil {
+        return nil, fmt.Errorf("failed to apply scorers: %w", err)
+    }
+
+    return selectedPod, nil
 }
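NewSessionAffinityScorer itself lives in one of the changed files not shown in this view. Judging from the interfaces above, a plausible shape for it, sketched under that assumption rather than taken from the commit, is: give the pod already recorded for this session a score equal to the scorer's weight, and every other pod zero.

// Sketch only: the real NewSessionAffinityScorer is in a file not shown
// here, and how it reads the session id from the scheduling context is
// assumed, not confirmed by this diff.
type sessionAffinityScorer struct {
    weight    float64
    datastore Datastore
}

func (s sessionAffinityScorer) ScoreTargets(ctx *types.Context, pods []*types.PodMetrics) ([]PodScore, error) {
    scores := make([]PodScore, 0, len(pods))
    sessionPod := s.datastore.GetPodForSession(ctx.Req.SessionID) // assumes the context carries the LLMRequest

    for _, pod := range pods {
        score := 0.0
        if sessionPod != nil && pod.GetPod().NamespacedName == sessionPod.NamespacedName {
            score = s.weight
        }
        scores = append(scores, PodScore{Score: score, Pod: pod})
    }
    return scores, nil
}

With weight 1 and all other scores at zero, the session's previous pod wins scoreTargets whenever it survives the filter stage; if it was filtered out, or the session is unknown, every pod ties at zero and selection falls back to a random pick, matching the scheduler's old behavior.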

pkg/epp/scheduling/scheduler_test.go (+4)

@@ -230,3 +230,7 @@ func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {
     }
     return pm
 }
+
+func (fds *fakeDataStore) GetPodForSession(sessionID string) *backendmetrics.Pod {
+    return nil
+}

pkg/epp/scheduling/scorer.go (new file, +112)

/*
Copyright 2025 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package scheduling

import (
    "fmt"
    "math/rand/v2"

    "sigs.k8s.io/controller-runtime/pkg/log"
    "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

type PodScore struct {
    Score float64
    Pod   *types.PodMetrics
}

// Scorer is the interface that scorers must implement
type Scorer interface {
    ScoreTargets(ctx *types.Context, pods []*types.PodMetrics) ([]PodScore, error)
}

// ScorerMng manages the set of scorers used to rank candidate pods
type ScorerMng struct {
    scorers []Scorer
}

func NewScorerMng() *ScorerMng {
    return &ScorerMng{
        scorers: make([]Scorer, 0),
    }
}

func (sm *ScorerMng) addScorer(scorer Scorer) {
    sm.scorers = append(sm.scorers, scorer)
}

func (sm *ScorerMng) scoreTargets(ctx *types.Context, pods []*types.PodMetrics) (*types.PodMetrics, error) {
    logger := log.FromContext(ctx)

    podsTotalScore := make(map[*types.PodMetrics]float64)
    validPods := make([]*types.PodMetrics, 0)

    // initialize zero score for all pods + check that pods are valid
    for _, pod := range pods {
        if pod == nil || pod.Pod == nil || pod.Metrics == nil {
            logger.Info("Invalid/empty pod skipped in scoring process")
        } else {
            validPods = append(validPods, pod)
            podsTotalScore[pod] = 0.0
        }
    }

    if len(validPods) == 0 {
        return nil, fmt.Errorf("empty list of valid pods to score")
    }

    // add scores from all scorers
    for _, scorer := range sm.scorers {
        scoredPods, err := scorer.ScoreTargets(ctx, validPods)
        if err != nil {
            // if a scorer fails, drop it from the total score but continue with the other scorers
            logger.Error(err, "Score targets returned error in scorer")
        } else {
            for _, scoredPod := range scoredPods {
                podsTotalScore[scoredPod.Pod] += scoredPod.Score
            }
        }
    }

    // select the pod with the maximum score; if more than one pod has the max score, pick one of them at random
    var highestScoreTargets []*types.PodMetrics
    // score weights could be negative, so seed the max from the first pod rather than from zero
    maxScore := 0.0
    isFirst := true

    for pod, score := range podsTotalScore {
        if isFirst {
            maxScore = score
            highestScoreTargets = []*types.PodMetrics{pod}
            isFirst = false
        } else {
            if score > maxScore {
                maxScore = score
                highestScoreTargets = []*types.PodMetrics{pod}
            } else if score == maxScore {
                highestScoreTargets = append(highestScoreTargets, pod)
            }
        }
    }

    // single pod with max score
    if len(highestScoreTargets) == 1 {
        return highestScoreTargets[0], nil
    }

    // select a random pod from the list of pods with the max score
    return highestScoreTargets[rand.IntN(len(highestScoreTargets))], nil
}
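Because Scorer is a one-method interface, extending the scheduler means writing another ScoreTargets and registering it with addScorer. The sketch below models the same sum-then-argmax-with-random-tie-break selection as a self-contained program, using a hypothetical queue-length scorer (not part of this commit) as the example signal and simplified stand-ins for the project's types:

package main

import (
    "fmt"
    "math/rand/v2"
)

// Simplified stand-ins for types.PodMetrics and the PodScore/Scorer pair.
type Pod struct {
    Name       string
    QueueDepth int
}

type PodScore struct {
    Score float64
    Pod   *Pod
}

type Scorer interface {
    ScoreTargets(pods []*Pod) ([]PodScore, error)
}

// queueLengthScorer is hypothetical: shorter queues score higher.
type queueLengthScorer struct{ weight float64 }

func (s queueLengthScorer) ScoreTargets(pods []*Pod) ([]PodScore, error) {
    scores := make([]PodScore, 0, len(pods))
    for _, p := range pods {
        scores = append(scores, PodScore{Score: s.weight / float64(1+p.QueueDepth), Pod: p})
    }
    return scores, nil
}

// pickBest mirrors scoreTargets: sum scores per pod, take the argmax,
// and break ties uniformly at random.
func pickBest(scorers []Scorer, pods []*Pod) *Pod {
    total := map[*Pod]float64{}
    for _, p := range pods {
        total[p] = 0
    }
    for _, sc := range scorers {
        scores, err := sc.ScoreTargets(pods)
        if err != nil {
            continue // a failing scorer is skipped, as in scoreTargets
        }
        for _, ps := range scores {
            total[ps.Pod] += ps.Score
        }
    }

    var best []*Pod
    first := true
    max := 0.0
    for p, s := range total {
        switch {
        case first || s > max:
            max, best, first = s, []*Pod{p}, false
        case s == max:
            best = append(best, p)
        }
    }
    if len(best) == 0 {
        return nil
    }
    return best[rand.IntN(len(best))]
}

func main() {
    pods := []*Pod{{"pod-a", 4}, {"pod-b", 0}, {"pod-c", 0}}
    winner := pickBest([]Scorer{queueLengthScorer{weight: 1}}, pods)
    fmt.Println("selected:", winner.Name) // pod-b or pod-c, tie broken randomly
}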
