Commit cf82d46

added interfaces to support scheduler refactoring
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent 8b9aef6 commit cf82d46

10 files changed: 238 additions & 184 deletions


pkg/epp/backend/metrics/pod_metrics.go

Lines changed: 4 additions & 5 deletions
@@ -41,9 +41,8 @@ type podMetrics struct {
 	ds       Datastore
 	interval time.Duration
 
-	parentCtx context.Context
-	once      sync.Once // ensure the StartRefreshLoop is only called once.
-	done      chan struct{}
+	once sync.Once // ensure the StartRefreshLoop is only called once.
+	done chan struct{}
 
 	logger logr.Logger
 }
@@ -80,7 +79,7 @@ func toInternalPod(in *corev1.Pod) *Pod {
 
 // start starts a goroutine exactly once to periodically update metrics. The goroutine will be
 // stopped either when stop() is called, or the parentCtx is cancelled.
-func (pm *podMetrics) startRefreshLoop() {
+func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
 	pm.once.Do(func() {
 		go func() {
 			pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod())
@@ -90,7 +89,7 @@ func (pm *podMetrics) startRefreshLoop() {
 			select {
 			case <-pm.done:
 				return
-			case <-pm.parentCtx.Done():
+			case <-ctx.Done():
 				return
 			case <-ticker.C: // refresh metrics periodically
 				if err := pm.refreshMetrics(); err != nil {
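
The refactor above follows the usual Go guideline of passing a context.Context as a function argument instead of storing it in a struct field. A minimal sketch of that pattern, with hypothetical names (refreshLoop, the intervals) rather than the actual podMetrics code:

package main

import (
	"context"
	"fmt"
	"time"
)

// refreshLoop is a hypothetical stand-in for startRefreshLoop: the caller's
// context arrives as an argument, so the loop exits when that context is
// cancelled or the done channel is closed.
func refreshLoop(ctx context.Context, done <-chan struct{}, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-done:
			return
		case <-ctx.Done():
			return
		case <-ticker.C:
			fmt.Println("refresh") // placeholder for refreshMetrics()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	refreshLoop(ctx, make(chan struct{}), 100*time.Millisecond) // returns once ctx expires
}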

pkg/epp/backend/metrics/types.go

Lines changed: 15 additions & 8 deletions
@@ -43,18 +43,17 @@ type PodMetricsFactory struct {
 func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics {
 	pod := toInternalPod(in)
 	pm := &podMetrics{
-		pmc:       f.pmc,
-		ds:        ds,
-		interval:  f.refreshMetricsInterval,
-		parentCtx: parentCtx,
-		once:      sync.Once{},
-		done:      make(chan struct{}),
-		logger:    log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
+		pmc:      f.pmc,
+		ds:       ds,
+		interval: f.refreshMetricsInterval,
+		once:     sync.Once{},
+		done:     make(chan struct{}),
+		logger:   log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName),
 	}
 	pm.pod.Store(pod)
 	pm.metrics.Store(newMetrics())
 
-	pm.startRefreshLoop()
+	pm.startRefreshLoop(parentCtx)
 	return pm
 }
 
@@ -79,6 +78,10 @@ func (p *Pod) String() string {
 }
 
 func (p *Pod) Clone() *Pod {
+	if p == nil {
+		return nil
+	}
+
 	return &Pod{
 		NamespacedName: types.NamespacedName{
 			Name: p.NamespacedName.Name,
@@ -118,6 +121,10 @@ func (m *Metrics) String() string {
 }
 
 func (m *Metrics) Clone() *Metrics {
+	if m == nil {
+		return nil
+	}
+
 	cm := make(map[string]int, len(m.ActiveModels))
 	for k, v := range m.ActiveModels {
 		cm[k] = v
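
The nil checks added to the Clone methods make them safe to call on pointers that were never set. A small sketch of the same idiom, assuming a hypothetical Config type rather than the project's Pod or Metrics:

package clonedemo

// Config is a hypothetical type used only to illustrate the nil-receiver guard.
type Config struct {
	Labels map[string]string
}

// Clone returns a deep copy. Calling it on a nil *Config returns nil instead
// of panicking, so callers can clone optional fields without an extra check.
func (c *Config) Clone() *Config {
	if c == nil {
		return nil
	}
	labels := make(map[string]string, len(c.Labels))
	for k, v := range c.Labels {
		labels[k] = v
	}
	return &Config{Labels: labels}
}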

pkg/epp/handlers/request.go

Lines changed: 4 additions & 5 deletions
@@ -67,7 +67,7 @@ func (s *StreamingServer) HandleRequestBody(
 		ResolvedTargetModel: modelName,
 		Critical:            modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
 	}
-	logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)
+	logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)
 
 	var err error
 	// Update target models in the body.
@@ -81,11 +81,11 @@ func (s *StreamingServer) HandleRequestBody(
 		return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
 	}
 
-	target, err := s.scheduler.Schedule(ctx, llmReq)
+	res, err := s.scheduler.Schedule(ctx, llmReq)
 	if err != nil {
 		return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
 	}
-	targetPod := target.GetPod()
+	targetPod := res.TargetPod.GetPod()
 
 	// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
 	// Attach the port number
@@ -96,8 +96,7 @@ func (s *StreamingServer) HandleRequestBody(
 	endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
 
 	logger.V(logutil.DEFAULT).Info("Request handled",
-		"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics",
-		fmt.Sprintf("%+v", target))
+		"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod)
 
 	reqCtx.Model = llmReq.Model
 	reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
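
The handler now reads the chosen pod from res.TargetPod. The Result type itself is not shown in this commit, so the following is only an inferred sketch of its shape based on the call sites above; the real definition in pkg/epp/scheduling/types may differ:

package typesketch

// Pod is assumed to expose the backend pod picked by the scheduler.
type Pod interface {
	GetPod() *BackendPod
}

// BackendPod stands in for the internal pod record; only the field the
// handler reads (Address) is sketched here.
type BackendPod struct {
	Address string
}

// Result wraps the scheduler's decision; the handler reads TargetPod.
type Result struct {
	TargetPod Pod
}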

pkg/epp/handlers/server.go

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ type StreamingServer struct {
 }
 
 type Scheduler interface {
-	Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error)
+	Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error)
 }
 
 // RequestContext stores context information during the life time of an HTTP request.
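
Because StreamingServer depends only on this small interface, the scheduler is straightforward to stub in handler tests. A hedged sketch of such a test double, assuming only what this diff shows about schedulingtypes; fakeScheduler is a hypothetical name, not an existing helper:

package handlers_test

import (
	"context"

	schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// fakeScheduler satisfies the Scheduler interface above and returns whatever
// result or error it was seeded with.
type fakeScheduler struct {
	res *schedulingtypes.Result
	err error
}

func (f *fakeScheduler) Schedule(ctx context.Context, req *schedulingtypes.LLMRequest) (*schedulingtypes.Result, error) {
	return f.res, f.err
}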

pkg/epp/scheduling/filter.go

Lines changed: 48 additions & 55 deletions
@@ -17,20 +17,15 @@ limitations under the License.
 package scheduling
 
 import (
-	"errors"
 	"math"
 	"math/rand"
 	"time"
 
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-type Filter interface {
-	Name() string
-	Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error)
-}
-
 type basicFilter struct {
 	name   string
 	filter filterFunc
@@ -43,7 +38,7 @@ func (bf *basicFilter) Name() string {
 	return bf.name
 }
 
-func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+func (bf *basicFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 	loggerTrace := ctx.Logger.V(logutil.TRACE)
 	loggerTrace.Info("Running a filter", "name", bf.Name(), "podCount", len(pods))
 
@@ -54,19 +49,19 @@ func (bf *basicFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*
 // depending success or failure of the current filter.
 // It can be used to construct a flow chart algorithm.
 type decisionTreeFilter struct {
-	current Filter
+	current plugins.Filter
 	// nextOnSuccess filter will be applied after successfully applying the current filter.
 	// The filtered results will be passed to the next filter.
-	nextOnSuccess Filter
+	nextOnSuccess plugins.Filter
 	// nextOnFailure filter will be applied if current filter fails.
 	// The original input will be passed to the next filter.
-	nextOnFailure Filter
+	nextOnFailure plugins.Filter
 	// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
 	// success or failure of the current filter.
 	// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
 	// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
 	// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
-	nextOnSuccessOrFailure Filter
+	nextOnSuccessOrFailure plugins.Filter
 }
 
 func (f *decisionTreeFilter) Name() string {
@@ -76,15 +71,15 @@ func (f *decisionTreeFilter) Name() string {
 	return f.current.Name()
 }
 
-func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+func (f *decisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 	loggerTrace := ctx.Logger.V(logutil.TRACE)
-	filtered, err := f.current.Filter(ctx, pods)
+	filtered := f.current.Filter(ctx, pods)
 
 	next := f.nextOnSuccessOrFailure
-	if err == nil && len(filtered) > 0 {
+	if len(filtered) > 0 {
 		if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
 			// No succeeding filters to run, return.
-			return filtered, err
+			return filtered
 		}
 		if f.nextOnSuccess != nil {
 			next = f.nextOnSuccess
@@ -95,7 +90,7 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics
 	} else {
 		if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
 			// No succeeding filters to run, return.
-			return filtered, err
+			return filtered
 		}
 		if f.nextOnFailure != nil {
 			next = f.nextOnFailure
@@ -107,22 +102,20 @@ func (f *decisionTreeFilter) Filter(ctx *types.Context, pods []*types.PodMetrics
 }
 
 // filterFunc filters a set of input pods to a subset.
-type filterFunc func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error)
+type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
 
 // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
 func toFilterFunc(pp podPredicate) filterFunc {
-	return func(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
-		filtered := []*types.PodMetrics{}
+	return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+		filtered := []types.Pod{}
 		for _, pod := range pods {
 			pass := pp(ctx.Req, pod)
 			if pass {
 				filtered = append(filtered, pod)
 			}
 		}
-		if len(filtered) == 0 {
-			return nil, errors.New("no pods left")
-		}
-		return filtered, nil
+
+		return filtered
 	}
 }
 
@@ -138,26 +131,26 @@ var leastQueueFilter = &basicFilter{
 // the least one as it gives more choices for the next filter, which on aggregate gave better
 // results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastQueuingFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 	min := math.MaxInt
 	max := 0
-	filtered := []*types.PodMetrics{}
+	filtered := []types.Pod{}
 
 	for _, pod := range pods {
-		if pod.WaitingQueueSize <= min {
-			min = pod.WaitingQueueSize
+		if pod.GetMetrics().WaitingQueueSize <= min {
+			min = pod.GetMetrics().WaitingQueueSize
 		}
-		if pod.WaitingQueueSize >= max {
-			max = pod.WaitingQueueSize
+		if pod.GetMetrics().WaitingQueueSize >= max {
+			max = pod.GetMetrics().WaitingQueueSize
 		}
 	}
 
 	for _, pod := range pods {
-		if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
+		if pod.GetMetrics().WaitingQueueSize >= min && pod.GetMetrics().WaitingQueueSize <= min+(max-min)/len(pods) {
			filtered = append(filtered, pod)
 		}
 	}
-	return filtered, nil
+	return filtered
 }
 
 var lowQueueFilter = &basicFilter{
@@ -176,26 +169,26 @@ var leastKVCacheFilter = &basicFilter{
 // should consider them all instead of the absolute minimum one. This worked better than picking the
 // least one as it gives more choices for the next filter, which on aggregate gave better results.
 // TODO: Compare this strategy with other strategies such as top K.
-func leastKVCacheFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 	min := math.MaxFloat64
 	var max float64 = 0
-	filtered := []*types.PodMetrics{}
+	filtered := []types.Pod{}
 
 	for _, pod := range pods {
-		if pod.KVCacheUsagePercent <= min {
-			min = pod.KVCacheUsagePercent
+		if pod.GetMetrics().KVCacheUsagePercent <= min {
+			min = pod.GetMetrics().KVCacheUsagePercent
 		}
-		if pod.KVCacheUsagePercent >= max {
-			max = pod.KVCacheUsagePercent
+		if pod.GetMetrics().KVCacheUsagePercent >= max {
+			max = pod.GetMetrics().KVCacheUsagePercent
 		}
 	}
 
 	for _, pod := range pods {
-		if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
+		if pod.GetMetrics().KVCacheUsagePercent >= min && pod.GetMetrics().KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
			filtered = append(filtered, pod)
 		}
 	}
-	return filtered, nil
+	return filtered
 }
 
 var loRAAffinityFilter = &basicFilter{
@@ -219,20 +212,20 @@ var loRAAffinityFilter = &basicFilter{
 // Returns:
 // - Filtered slice of pod metrics based on affinity and availability
 // - Error if any issues occur during filtering
-func loRASoftAffinityFilterFunc(ctx *types.Context, pods []*types.PodMetrics) ([]*types.PodMetrics, error) {
+func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
 
 	// Pre-allocate slices with estimated capacity
-	filtered_affinity := make([]*types.PodMetrics, 0, len(pods))
-	filtered_available := make([]*types.PodMetrics, 0, len(pods))
+	filtered_affinity := make([]types.Pod, 0, len(pods))
+	filtered_available := make([]types.Pod, 0, len(pods))
 
 	// Categorize pods based on affinity and availability
 	for _, pod := range pods {
-		_, active := pod.ActiveModels[ctx.Req.ResolvedTargetModel]
-		_, waiting := pod.WaitingModels[ctx.Req.ResolvedTargetModel]
+		_, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel]
+		_, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel]
 
 		if active || waiting {
 			filtered_affinity = append(filtered_affinity, pod)
-		} else if len(pod.ActiveModels)+len(pod.WaitingModels) < pod.MaxActiveModels {
+		} else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels {
			filtered_available = append(filtered_available, pod)
 		}
 	}
@@ -244,36 +237,36 @@ func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod)
 	// If both groups have pods, use probability to select which group to return
 	if len(filtered_affinity) > 0 && len(filtered_available) > 0 {
 		if randGen.Float64() < config.LoraAffinityThreshold {
-			return filtered_affinity, nil
+			return filtered_affinity
 		}
-		return filtered_available, nil
+		return filtered_available
 	}
 
 	// Return whichever group has pods
 	if len(filtered_affinity) > 0 {
-		return filtered_affinity, nil
+		return filtered_affinity
 	}
 
-	return filtered_available, nil
+	return filtered_available
 }
 
 // podPredicate is a filter function to check whether a pod is desired.
-type podPredicate func(req *types.LLMRequest, pod *types.PodMetrics) bool
+type podPredicate func(req *types.LLMRequest, pod types.Pod) bool
 
 func queueThresholdPredicate(queueThreshold int) podPredicate {
-	return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
-		return pod.WaitingQueueSize <= queueThreshold
+	return func(req *types.LLMRequest, pod types.Pod) bool {
+		return pod.GetMetrics().WaitingQueueSize <= queueThreshold
 	}
 }
 
 func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate {
-	return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
-		return pod.KVCacheUsagePercent <= kvCacheThreshold
+	return func(req *types.LLMRequest, pod types.Pod) bool {
		return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold
 	}
 }
 
 func (pp podPredicate) and(another podPredicate) podPredicate {
-	return func(req *types.LLMRequest, pod *types.PodMetrics) bool {
+	return func(req *types.LLMRequest, pod types.Pod) bool {
 		return pp(req, pod) && another(req, pod)
 	}
 }
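
With Filter now defined as an interface in the plugins package, adding a scheduling filter means implementing Name() plus the new Filter signature. A minimal sketch under that assumption (the plugins.Filter definition itself is not part of this diff; hasCapacityFilter and its threshold are illustrative only):

package example

import (
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// hasCapacityFilter is a hypothetical filter that keeps pods whose KV cache
// usage is at or below a fixed threshold, using the same signature as the
// filters refactored above.
type hasCapacityFilter struct {
	threshold float64
}

func (f *hasCapacityFilter) Name() string { return "has-capacity" }

func (f *hasCapacityFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
	filtered := []types.Pod{}
	for _, pod := range pods {
		if pod.GetMetrics().KVCacheUsagePercent <= f.threshold {
			filtered = append(filtered, pod)
		}
	}
	return filtered
}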
