
Commit 18148a9

Merge pull request kubernetes-sigs#138 from shmuelk/llm-response
[UPSTREAM-SYNC] Add LLMResponse object and RequestId to LLMRequest
2 parents fd1ddfa + 67bfe89

File tree: 7 files changed (+42, -8 lines)

pkg/epp/requestcontrol/director.go

Lines changed: 2 additions & 0 deletions

```diff
@@ -31,6 +31,7 @@ import (
     schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
     errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
     logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+    requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
 )

 type Scheduler interface {
@@ -82,6 +83,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
     }

     llmReq := &schedulingtypes.LLMRequest{
+        RequestId:           reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
         Model:               reqCtx.Model,
         ResolvedTargetModel: reqCtx.ResolvedTargetModel,
         Critical:            modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
```
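For readers tracking where the new field comes from: `RequestIdHeaderKey` is a constant in the `util/request` package, and per the struct comments below it carries the Envoy-generated request ID. A minimal sketch of the lookup, assuming the key resolves to Envoy's conventional `x-request-id` header (the constant's actual value is not shown in this diff):

```go
package main

import "fmt"

// RequestIdHeaderKey stands in for requtil.RequestIdHeaderKey; its value here
// is an assumption (Envoy's conventional request ID header), not taken from the diff.
const RequestIdHeaderKey = "x-request-id"

func main() {
    // reqCtx.Request.Headers in the diff is a plain map from header name to value.
    headers := map[string]string{
        RequestIdHeaderKey: "3f6f7a2d-8f6b-4c1e-9d2a-0b1c2d3e4f50",
        "content-type":     "application/json",
    }

    // The director copies the ID straight into the LLMRequest; a missing
    // header yields the empty string, which downstream code must tolerate.
    requestId := headers[RequestIdHeaderKey]
    fmt.Println("request id:", requestId)
}
```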

pkg/epp/scheduling/plugins/filter/filter_test.go

Lines changed: 5 additions & 3 deletions

```diff
@@ -21,6 +21,7 @@ import (
     "testing"

     "github.com/google/go-cmp/cmp"
+    "github.com/google/uuid"
     k8stypes "k8s.io/apimachinery/pkg/types"
     "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
     backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
@@ -52,7 +53,7 @@ func TestFilter(t *testing.T) {

     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+            ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
             got := test.filter.Filter(ctx, test.input)

             if diff := cmp.Diff(test.output, got); diff != "" {
@@ -187,7 +188,7 @@ func TestFilterFunc(t *testing.T) {

     for _, test := range tests {
         t.Run(test.name, func(t *testing.T) {
-            ctx := types.NewSchedulingContext(context.Background(), test.req, test.input)
+            ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
             got := test.f(ctx, test.input)

             if diff := cmp.Diff(test.output, got); diff != "" {
@@ -221,6 +222,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {

     // Create a test request and pods
     req := &types.LLMRequest{
+        RequestId:           uuid.NewString(),
         Model:               testAffinityModel,
         ResolvedTargetModel: testAffinityModel,
     }
@@ -244,7 +246,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
            },
        },
    }
-    ctx := types.NewSchedulingContext(context.Background(), req, pods)
+    ctx := types.NewSchedulingContext(context.Background(), req, nil, pods)

    // Run the filter function multiple times and count the results
    affinityCount := 0
```
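The tests now seed every `LLMRequest` with a fresh ID via `uuid.NewString()` from `github.com/google/uuid`, which returns a random version-4 UUID as a string. A tiny self-contained sketch of the pattern (the anonymous struct is illustrative, not the real type):

```go
package main

import (
    "fmt"

    "github.com/google/uuid"
)

func main() {
    // uuid.NewString is shorthand for uuid.New().String(): a random V4 UUID.
    req := struct {
        RequestId string
        Model     string
    }{
        RequestId: uuid.NewString(),
        Model:     "test-model",
    }
    fmt.Println(req.RequestId, req.Model)
}
```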

pkg/epp/scheduling/plugins/scorer/kvcache_test.go

Lines changed: 1 addition & 1 deletion

```diff
@@ -82,7 +82,7 @@ func TestKvCacheScorer(t *testing.T) {

     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
+            ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
             scorer := &KVCacheScorer{}
             scores := scorer.Score(ctx, tt.pods)
```

pkg/epp/scheduling/plugins/scorer/queue_test.go

Lines changed: 1 addition & 1 deletion

```diff
@@ -73,7 +73,7 @@ func TestQueueScorer(t *testing.T) {

     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
-            ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, tt.pods)
+            ctx := types.NewSchedulingContext(context.Background(), &types.LLMRequest{}, nil, tt.pods)
             scores := scorer.Score(ctx, tt.pods)

             for i, pod := range tt.pods {
```

pkg/epp/scheduling/scheduler.go

Lines changed: 1 addition & 1 deletion

```diff
@@ -108,7 +108,7 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types
     // Snapshot pod metrics from the datastore to:
     // 1. Reduce concurrent access to the datastore.
     // 2. Ensure consistent data during the scheduling operation of a request.
-    sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
+    sCtx := types.NewSchedulingContext(ctx, req, nil, types.ToSchedulerPodMetrics(s.datastore.PodGetAll()))
     loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot))

     return s.ScheduleWithContext(sCtx, req)
```
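Every call site in this commit follows the same mechanical migration: the response becomes the third argument, and on the request-scheduling path no response exists yet, so `nil` is passed. A simplified, self-contained sketch of the convention (illustrative stand-in types, not the real package API):

```go
package main

import "fmt"

// Illustrative stand-ins for the types in pkg/epp/scheduling/types.
type LLMRequest struct{ RequestId, Model string }
type LLMResponse struct{ RequestId string }

type SchedulingContext struct {
    Req  *LLMRequest
    Resp *LLMResponse // nil until response processing begins
}

func NewSchedulingContext(req *LLMRequest, resp *LLMResponse) *SchedulingContext {
    return &SchedulingContext{Req: req, Resp: resp}
}

func main() {
    req := &LLMRequest{RequestId: "abc-123", Model: "test-model"}

    // Request path (as in Schedule above): no response yet, pass nil.
    sCtx := NewSchedulingContext(req, nil)
    fmt.Println("has response:", sCtx.Resp != nil) // false

    // Response path: a plugin inspecting the response would see it populated.
    sCtx = NewSchedulingContext(req, &LLMResponse{RequestId: req.RequestId})
    fmt.Println("has response:", sCtx.Resp != nil) // true
}
```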

pkg/epp/scheduling/scheduler_test.go

Lines changed: 8 additions & 1 deletion

```diff
@@ -21,6 +21,7 @@ import (
     "testing"

     "github.com/google/go-cmp/cmp"
+    "github.com/google/uuid"
     k8stypes "k8s.io/apimachinery/pkg/types"
     "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
     backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
@@ -40,6 +41,7 @@ func TestSchedule(t *testing.T) {
         {
             name: "no pods in datastore",
             req: &types.LLMRequest{
+                RequestId:           uuid.NewString(),
                 Model:               "any-model",
                 ResolvedTargetModel: "any-model",
                 Critical:            true,
@@ -50,6 +52,7 @@ func TestSchedule(t *testing.T) {
         {
             name: "critical request",
             req: &types.LLMRequest{
+                RequestId:           uuid.NewString(),
                 Model:               "critical",
                 ResolvedTargetModel: "critical",
                 Critical:            true,
@@ -114,6 +117,7 @@ func TestSchedule(t *testing.T) {
         {
             name: "sheddable request, accepted",
             req: &types.LLMRequest{
+                RequestId:           uuid.NewString(),
                 Model:               "sheddable",
                 ResolvedTargetModel: "sheddable",
                 Critical:            false,
@@ -177,6 +181,7 @@ func TestSchedule(t *testing.T) {
         {
             name: "sheddable request, dropped",
             req: &types.LLMRequest{
+                RequestId:           uuid.NewString(),
                 Model:               "sheddable",
                 ResolvedTargetModel: "sheddable",
                 Critical:            false,
@@ -356,7 +361,9 @@ func TestSchedulePlugins(t *testing.T) {
         // Initialize the scheduler
         scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)

-        req := &types.LLMRequest{Model: "test-model"}
+        req := &types.LLMRequest{
+            RequestId: uuid.NewString(),
+            Model:     "test-model"}
         got, err := scheduler.Schedule(context.Background(), req)

         // Validate error state
```

pkg/epp/scheduling/types/types.go

Lines changed: 24 additions & 1 deletion

```diff
@@ -28,6 +28,9 @@ import (

 // LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
 type LLMRequest struct {
+    // RequestId is the Envoy-generated ID for the request being processed.
+    RequestId string
+
     // Model is the name of the model that the user specified in the request body.
     Model string
     // ResolvedTargetModel is the final target model after traffic split.
@@ -45,6 +48,24 @@ func (r *LLMRequest) String() string {
         r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers)
 }

+// LLMResponse contains information from the received response, to be passed to plugins.
+type LLMResponse struct {
+    // RequestId is the Envoy-generated ID for the request being processed.
+    RequestId string
+
+    // Headers is a map of the response headers; nil during body processing.
+    Headers map[string]string
+
+    // Body is the body of the response; empty during header processing.
+    Body string
+
+    // IsStreaming indicates whether the response is being streamed by the model.
+    IsStreaming bool
+
+    // EndOfStream, when true, indicates that this invocation contains the last chunk of the response.
+    EndOfStream bool
+}
+
 type Pod interface {
     GetPod() *backend.Pod
     GetMetrics() *backendmetrics.Metrics
@@ -61,6 +82,7 @@ type SchedulingContext struct {
     context.Context
     Logger       logr.Logger
     Req          *LLMRequest
+    Resp         *LLMResponse
     PodsSnapshot []Pod
 }

@@ -84,12 +106,13 @@ type PodMetrics struct {
     *backendmetrics.Metrics
 }

-func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext {
+func NewSchedulingContext(ctx context.Context, req *LLMRequest, resp *LLMResponse, pods []Pod) *SchedulingContext {
     logger := log.FromContext(ctx).WithValues("request", req)
     return &SchedulingContext{
         Context:      ctx,
         Logger:       logger,
         Req:          req,
+        Resp:         resp,
         PodsSnapshot: pods,
     }
 }
```
