Skip to content

Commit 65168db

Browse files
committed
Add unit tests for request body
1 parent 855436e commit 65168db

File tree

1 file changed

+190
-0
lines changed

1 file changed

+190
-0
lines changed

pkg/epp/handlers/request_test.go

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package handlers
18+
19+
import (
20+
"context"
21+
"strings"
22+
"testing"
23+
"time"
24+
25+
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
26+
"github.com/google/go-cmp/cmp"
27+
corev1 "k8s.io/api/core/v1"
28+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
29+
"k8s.io/apimachinery/pkg/runtime"
30+
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
31+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
32+
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
33+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
34+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
35+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
36+
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
37+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
38+
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
39+
)
40+
41+
const (
42+
DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb" // default for --destinationEndpointHintMetadataNamespace
43+
DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint" // default for --destinationEndpointHintKey
44+
)
45+
46+
func TestHandleRequestBody(t *testing.T) {
47+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
48+
49+
// Setup datastore
50+
tsModel := "food-review"
51+
modelWithTarget := "food-review-0"
52+
model1ts := testutil.MakeInferenceModel("model1").
53+
CreationTimestamp(metav1.Unix(1000, 0)).
54+
ModelName(tsModel).ObjRef()
55+
model2 := testutil.MakeInferenceModel("model2").
56+
CreationTimestamp(metav1.Unix(1000, 0)).
57+
ModelName(modelWithTarget).ObjRef()
58+
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
59+
ds := datastore.NewDatastore(t.Context(), pmf)
60+
ds.ModelSetIfOlder(model1ts)
61+
ds.ModelSetIfOlder(model2)
62+
63+
pool := &v1alpha2.InferencePool{
64+
Spec: v1alpha2.InferencePoolSpec{
65+
TargetPortNumber: int32(8000),
66+
Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{
67+
"some-key": "some-val",
68+
},
69+
},
70+
}
71+
pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}}
72+
scheme := runtime.NewScheme()
73+
_ = clientgoscheme.AddToScheme(scheme)
74+
fakeClient := fake.NewClientBuilder().
75+
WithScheme(scheme).
76+
Build()
77+
if err := ds.PoolSet(ctx, fakeClient, pool); err != nil {
78+
t.Error(err, "Error while setting inference pool")
79+
}
80+
ds.PodUpdateOrAddIfNotExist(pod)
81+
82+
tests := []struct {
83+
name string
84+
reqBodyMap map[string]interface{}
85+
reqCtx *RequestContext
86+
wantErrCode string
87+
wantErr bool
88+
wantReqCtx *RequestContext
89+
wantRespBody map[string]interface{}
90+
}{
91+
{
92+
name: "successful request",
93+
reqBodyMap: map[string]interface{}{
94+
"model": tsModel,
95+
"prompt": "test prompt",
96+
},
97+
wantReqCtx: &RequestContext{
98+
Model: tsModel,
99+
ResolvedTargetModel: tsModel,
100+
TargetPod: "/pod1",
101+
TargetEndpoint: "address-1:8000",
102+
},
103+
wantRespBody: map[string]interface{}{
104+
"model": tsModel,
105+
"prompt": "test prompt",
106+
},
107+
},
108+
{
109+
name: "successful request with target model",
110+
reqBodyMap: map[string]interface{}{
111+
"model": modelWithTarget,
112+
"prompt": "test prompt",
113+
},
114+
wantReqCtx: &RequestContext{
115+
Model: modelWithTarget,
116+
ResolvedTargetModel: modelWithTarget,
117+
TargetPod: "/pod1",
118+
TargetEndpoint: "address-1:8000",
119+
},
120+
wantRespBody: map[string]interface{}{
121+
"model": modelWithTarget,
122+
"prompt": "test prompt",
123+
},
124+
},
125+
{
126+
name: "no model defined, expect err",
127+
wantErr: true,
128+
wantErrCode: errutil.BadRequest,
129+
},
130+
{
131+
name: "invalid model defined, expect err",
132+
reqBodyMap: map[string]interface{}{
133+
"model": "non-existent-model",
134+
"prompt": "test prompt",
135+
},
136+
wantErr: true,
137+
wantErrCode: errutil.BadConfiguration,
138+
},
139+
{
140+
name: "invalid target defined, expect err",
141+
reqBodyMap: map[string]interface{}{
142+
"model": "food-review-1",
143+
"prompt": "test prompt",
144+
},
145+
wantErr: true,
146+
wantErrCode: errutil.BadConfiguration,
147+
},
148+
}
149+
150+
for _, test := range tests {
151+
t.Run(test.name, func(t *testing.T) {
152+
server := NewStreamingServer(scheduling.NewScheduler(ds), DefaultDestinationEndpointHintMetadataNamespace, DefaultDestinationEndpointHintKey, ds)
153+
reqCtx := test.reqCtx
154+
if reqCtx == nil {
155+
reqCtx = &RequestContext{}
156+
}
157+
req := &extProcPb.ProcessingRequest{}
158+
gotReqCtx, err := server.HandleRequestBody(ctx, reqCtx, req, test.reqBodyMap)
159+
160+
if test.wantErr {
161+
if err == nil {
162+
t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode)
163+
}
164+
if !strings.Contains(err.Error(), test.wantErrCode) {
165+
t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode)
166+
}
167+
return
168+
}
169+
170+
if err != nil {
171+
t.Fatalf("HandleRequestBody returned unexpected error: %v", err)
172+
}
173+
174+
if test.wantReqCtx != nil {
175+
if diff := cmp.Diff(test.wantReqCtx.Model, gotReqCtx.Model); diff != "" {
176+
t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff)
177+
}
178+
if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, gotReqCtx.ResolvedTargetModel); diff != "" {
179+
t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff)
180+
}
181+
if diff := cmp.Diff(test.wantReqCtx.TargetPod, gotReqCtx.TargetPod); diff != "" {
182+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff)
183+
}
184+
if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, gotReqCtx.TargetEndpoint); diff != "" {
185+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff)
186+
}
187+
}
188+
})
189+
}
190+
}

0 commit comments

Comments
 (0)