Skip to content

Commit cd84b7e

Browse files
committed
Add unit tests for request body
1 parent 2d2db35 commit cd84b7e

22 files changed

+807
-146
lines changed

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ FROM ${BUILDER_IMAGE} AS builder
88
ENV CGO_ENABLED=0
99
ENV GOOS=linux
1010
ENV GOARCH=amd64
11+
ARG COMMIT_SHA=unknown
1112

1213
# Dependencies
1314
WORKDIR /src
@@ -19,9 +20,8 @@ COPY cmd ./cmd
1920
COPY pkg ./pkg
2021
COPY internal ./internal
2122
COPY api ./api
22-
COPY .git ./.git
2323
WORKDIR /src/cmd/epp
24-
RUN go build -buildvcs=true -o /epp
24+
RUN go build -ldflags="-X sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics.CommitSHA=${COMMIT_SHA}" -o /epp
2525

2626
## Multistage deploy
2727
FROM ${BASE_IMAGE}

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ CONTAINER_TOOL ?= docker
2121
SHELL = /usr/bin/env bash -o pipefail
2222
.SHELLFLAGS = -ec
2323

24+
GIT_COMMIT_SHA ?= "$(shell git rev-parse HEAD 2>/dev/null)"
2425
GIT_TAG ?= $(shell git describe --tags --dirty --always)
2526
PLATFORMS ?= linux/amd64
2627
DOCKER_BUILDX_CMD ?= docker buildx
@@ -175,6 +176,7 @@ image-build: ## Build the EPP image using Docker Buildx.
175176
--platform=$(PLATFORMS) \
176177
--build-arg BASE_IMAGE=$(BASE_IMAGE) \
177178
--build-arg BUILDER_IMAGE=$(BUILDER_IMAGE) \
179+
--build-arg COMMIT_SHA=${GIT_COMMIT_SHA} \
178180
$(PUSH) \
179181
$(LOAD) \
180182
$(IMAGE_BUILD_EXTRA_OPTS) ./

cloudbuild.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ steps:
1212
- GIT_TAG=$_GIT_TAG
1313
- EXTRA_TAG=$_PULL_BASE_REF
1414
- DOCKER_BUILDX_CMD=/buildx-entrypoint
15+
- GIT_COMMIT_SHA=$COMMIT_SHA
1516
- name: gcr.io/k8s-staging-test-infra/gcb-docker-gcloud:v20240718-5ef92b5c36
1617
entrypoint: make
1718
args:

pkg/bbr/handlers/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"google.golang.org/grpc/status"
2929
"sigs.k8s.io/controller-runtime/pkg/log"
3030
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
31+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3132
)
3233

3334
func NewServer(streaming bool) *Server {
@@ -74,6 +75,11 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
7475
// If streaming and the body is not empty, then headers are handled when processing request body.
7576
loggerVerbose.Info("Received headers, passing off header processing until body arrives...")
7677
} else {
78+
if requestId := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestId) > 0 {
79+
logger = logger.WithValues(requtil.RequestIdHeaderKey, requestId)
80+
loggerVerbose = logger.V(logutil.VERBOSE)
81+
ctx = log.IntoContext(ctx, logger)
82+
}
7783
responses, err = s.HandleRequestHeaders(req.GetRequestHeaders())
7884
}
7985
case *extProcPb.ProcessingRequest_RequestBody:

pkg/epp/handlers/request.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ import (
3535
func (s *StreamingServer) HandleRequestBody(
3636
ctx context.Context,
3737
reqCtx *RequestContext,
38-
req *extProcPb.ProcessingRequest,
39-
requestBodyMap map[string]interface{},
4038
) (*RequestContext, error) {
4139
var requestBodyBytes []byte
4240
logger := log.FromContext(ctx)
41+
requestBodyMap := reqCtx.Request.Body
4342

4443
// Resolve target models.
4544
model, ok := requestBodyMap["model"].(string)
@@ -152,6 +151,15 @@ func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *Requ
152151
}
153152
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
154153
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
154+
return nil
155+
}
156+
157+
for _, header := range req.RequestHeaders.Headers.Headers {
158+
if header.RawValue != nil {
159+
reqCtx.Request.Headers[header.Key] = string(header.RawValue)
160+
} else {
161+
reqCtx.Request.Headers[header.Key] = header.Value
162+
}
155163
}
156164
return nil
157165
}

pkg/epp/handlers/request_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package handlers
18+
19+
import (
20+
"context"
21+
"strings"
22+
"testing"
23+
"time"
24+
25+
"github.com/google/go-cmp/cmp"
26+
corev1 "k8s.io/api/core/v1"
27+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28+
"k8s.io/apimachinery/pkg/runtime"
29+
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
30+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
31+
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
32+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
33+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
34+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
35+
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
36+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
37+
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
38+
)
39+
40+
// Default destination-endpoint hint settings passed to NewStreamingServer by the tests in this file.
const (
	// DefaultDestinationEndpointHintMetadataNamespace is the default for
	// --destinationEndpointHintMetadataNamespace.
	DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb"

	// DefaultDestinationEndpointHintKey is the default for
	// --destinationEndpointHintKey.
	DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint"
)
44+
45+
func TestHandleRequestBody(t *testing.T) {
46+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
47+
48+
// Setup datastore
49+
tsModel := "food-review"
50+
modelWithTarget := "food-review-0"
51+
model1 := testutil.MakeInferenceModel("model1").
52+
CreationTimestamp(metav1.Unix(1000, 0)).
53+
ModelName(tsModel).ObjRef()
54+
model2 := testutil.MakeInferenceModel("model2").
55+
CreationTimestamp(metav1.Unix(1000, 0)).
56+
ModelName(modelWithTarget).ObjRef()
57+
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
58+
ds := datastore.NewDatastore(t.Context(), pmf)
59+
ds.ModelSetIfOlder(model1)
60+
ds.ModelSetIfOlder(model2)
61+
62+
pool := &v1alpha2.InferencePool{
63+
Spec: v1alpha2.InferencePoolSpec{
64+
TargetPortNumber: int32(8000),
65+
Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{
66+
"some-key": "some-val",
67+
},
68+
},
69+
}
70+
pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}}
71+
scheme := runtime.NewScheme()
72+
_ = clientgoscheme.AddToScheme(scheme)
73+
fakeClient := fake.NewClientBuilder().
74+
WithScheme(scheme).
75+
Build()
76+
if err := ds.PoolSet(ctx, fakeClient, pool); err != nil {
77+
t.Error(err, "Error while setting inference pool")
78+
}
79+
ds.PodUpdateOrAddIfNotExist(pod)
80+
81+
tests := []struct {
82+
name string
83+
reqBodyMap map[string]interface{}
84+
wantErrCode string
85+
wantReqCtx *RequestContext
86+
wantRespBody map[string]interface{}
87+
}{
88+
{
89+
name: "successful request",
90+
reqBodyMap: map[string]interface{}{
91+
"model": tsModel,
92+
"prompt": "test prompt",
93+
},
94+
wantReqCtx: &RequestContext{
95+
Model: tsModel,
96+
ResolvedTargetModel: tsModel,
97+
TargetPod: "/pod1",
98+
TargetEndpoint: "address-1:8000",
99+
},
100+
wantRespBody: map[string]interface{}{
101+
"model": tsModel,
102+
"prompt": "test prompt",
103+
},
104+
},
105+
{
106+
name: "successful request with target model",
107+
reqBodyMap: map[string]interface{}{
108+
"model": modelWithTarget,
109+
"prompt": "test prompt",
110+
},
111+
wantReqCtx: &RequestContext{
112+
Model: modelWithTarget,
113+
ResolvedTargetModel: modelWithTarget,
114+
TargetPod: "/pod1",
115+
TargetEndpoint: "address-1:8000",
116+
},
117+
wantRespBody: map[string]interface{}{
118+
"model": modelWithTarget,
119+
"prompt": "test prompt",
120+
},
121+
},
122+
{
123+
name: "no model defined, expect err",
124+
wantErrCode: errutil.BadRequest,
125+
},
126+
{
127+
name: "invalid model defined, expect err",
128+
reqBodyMap: map[string]interface{}{
129+
"model": "non-existent-model",
130+
"prompt": "test prompt",
131+
},
132+
wantErrCode: errutil.BadConfiguration,
133+
},
134+
{
135+
name: "invalid target defined, expect err",
136+
reqBodyMap: map[string]interface{}{
137+
"model": "food-review-1",
138+
"prompt": "test prompt",
139+
},
140+
wantErrCode: errutil.BadConfiguration,
141+
},
142+
}
143+
144+
for _, test := range tests {
145+
t.Run(test.name, func(t *testing.T) {
146+
server := NewStreamingServer(scheduling.NewScheduler(ds), DefaultDestinationEndpointHintMetadataNamespace, DefaultDestinationEndpointHintKey, ds)
147+
reqCtx := &RequestContext{
148+
Request: &Request{
149+
Body: test.reqBodyMap,
150+
},
151+
}
152+
reqCtx, err := server.HandleRequestBody(ctx, reqCtx)
153+
154+
if test.wantErrCode != "" {
155+
if err == nil {
156+
t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode)
157+
}
158+
if !strings.Contains(err.Error(), test.wantErrCode) {
159+
t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode)
160+
}
161+
return
162+
}
163+
164+
if err != nil {
165+
t.Fatalf("HandleRequestBody returned unexpected error: %v", err)
166+
}
167+
168+
if test.wantReqCtx != nil {
169+
if diff := cmp.Diff(test.wantReqCtx.Model, reqCtx.Model); diff != "" {
170+
t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff)
171+
}
172+
if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, reqCtx.ResolvedTargetModel); diff != "" {
173+
t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff)
174+
}
175+
if diff := cmp.Diff(test.wantReqCtx.TargetPod, reqCtx.TargetPod); diff != "" {
176+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff)
177+
}
178+
if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, reqCtx.TargetEndpoint); diff != "" {
179+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff)
180+
}
181+
}
182+
})
183+
}
184+
}

pkg/epp/handlers/server.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4141
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
4242
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
43+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
4344
)
4445

4546
func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *StreamingServer {
@@ -82,6 +83,7 @@ type RequestContext struct {
8283
ResponseComplete bool
8384
ResponseStatusCode string
8485
RequestRunning bool
86+
Request *Request
8587

8688
RequestState StreamRequestState
8789
modelServerStreaming bool
@@ -95,6 +97,10 @@ type RequestContext struct {
9597
respTrailerResp *extProcPb.ProcessingResponse
9698
}
9799

100+
// Request holds the parts of the incoming request that are captured on the
// RequestContext while the stream is processed.
type Request struct {
	// Headers maps each captured request header key to its value.
	Headers map[string]string
	// Body is the request body decoded from JSON into a generic map.
	Body map[string]interface{}
}
98104
type StreamRequestState int
99105

100106
const (
@@ -118,10 +124,14 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
118124
// See https://github.com/envoyproxy/envoy/issues/17540.
119125
reqCtx := &RequestContext{
120126
RequestState: RequestReceived,
127+
Request: &Request{
128+
Headers: make(map[string]string),
129+
Body: make(map[string]interface{}),
130+
},
121131
}
122132

123133
var body []byte
124-
var requestBody, responseBody map[string]interface{}
134+
var responseBody map[string]interface{}
125135

126136
// Create error handling var as each request should only report once for
127137
// error metrics. This doesn't cover the error "Cannot receive stream request" because
@@ -158,6 +168,11 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
158168

159169
switch v := req.Request.(type) {
160170
case *extProcPb.ProcessingRequest_RequestHeaders:
171+
if requestId := requtil.ExtractHeaderValue(v, requtil.RequestIdHeaderKey); len(requestId) > 0 {
172+
logger = logger.WithValues(requtil.RequestIdHeaderKey, requestId)
173+
loggerTrace = logger.V(logutil.TRACE)
174+
ctx = log.IntoContext(ctx, logger)
175+
}
161176
err = s.HandleRequestHeaders(ctx, reqCtx, v)
162177
case *extProcPb.ProcessingRequest_RequestBody:
163178
loggerTrace.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream)
@@ -167,15 +182,17 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
167182
// Message is buffered, we can read and decode.
168183
if v.RequestBody.EndOfStream {
169184
loggerTrace.Info("decoding")
170-
err = json.Unmarshal(body, &requestBody)
185+
err = json.Unmarshal(body, &reqCtx.Request.Body)
171186
if err != nil {
172187
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
188+
// TODO: short circuit and send the body back as is (this could be an envoy error), currently we drop
189+
// whatever the body request would have been and send our immediate response instead.
173190
}
174191

175192
// Body stream complete. Allocate empty slice for response to use.
176193
body = []byte{}
177194

178-
reqCtx, err = s.HandleRequestBody(ctx, reqCtx, req, requestBody)
195+
reqCtx, err = s.HandleRequestBody(ctx, reqCtx)
179196
if err != nil {
180197
logger.V(logutil.DEFAULT).Error(err, "Error handling body")
181198
} else {
@@ -256,7 +273,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
256273
loggerTrace.Info("stream completed")
257274
// Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes.
258275
// We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message.
259-
// using the standard 'err' var will send an immediate error response back to the caller.
276+
// Using the standard 'err' var will send an immediate error response back to the caller.
260277
var responseErr error
261278
responseErr = json.Unmarshal(body, &responseBody)
262279
if responseErr != nil {

0 commit comments

Comments
 (0)