Skip to content

Commit 95d3e27

Browse files
rlakhtakiakaushikmitr
authored andcommitted
Add unit tests for request body (kubernetes-sigs#745)
1 parent ef600a2 commit 95d3e27

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed

pkg/epp/handlers/request_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package handlers
18+
19+
import (
20+
"context"
21+
"strings"
22+
"testing"
23+
"time"
24+
25+
"github.com/google/go-cmp/cmp"
26+
corev1 "k8s.io/api/core/v1"
27+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28+
"k8s.io/apimachinery/pkg/runtime"
29+
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
30+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
31+
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
32+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
33+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
34+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
35+
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
36+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
37+
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
38+
)
39+
40+
const (
41+
DefaultDestinationEndpointHintMetadataNamespace = "envoy.lb" // default for --destinationEndpointHintMetadataNamespace
42+
DefaultDestinationEndpointHintKey = "x-gateway-destination-endpoint" // default for --destinationEndpointHintKey
43+
)
44+
45+
func TestHandleRequestBody(t *testing.T) {
46+
ctx := logutil.NewTestLoggerIntoContext(context.Background())
47+
48+
// Setup datastore
49+
tsModel := "food-review"
50+
modelWithTarget := "food-review-0"
51+
model1 := testutil.MakeInferenceModel("model1").
52+
CreationTimestamp(metav1.Unix(1000, 0)).
53+
ModelName(tsModel).ObjRef()
54+
model2 := testutil.MakeInferenceModel("model2").
55+
CreationTimestamp(metav1.Unix(1000, 0)).
56+
ModelName(modelWithTarget).ObjRef()
57+
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
58+
ds := datastore.NewDatastore(t.Context(), pmf)
59+
ds.ModelSetIfOlder(model1)
60+
ds.ModelSetIfOlder(model2)
61+
62+
pool := &v1alpha2.InferencePool{
63+
Spec: v1alpha2.InferencePoolSpec{
64+
TargetPortNumber: int32(8000),
65+
Selector: map[v1alpha2.LabelKey]v1alpha2.LabelValue{
66+
"some-key": "some-val",
67+
},
68+
},
69+
}
70+
pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}, Status: corev1.PodStatus{PodIP: "address-1"}}
71+
scheme := runtime.NewScheme()
72+
_ = clientgoscheme.AddToScheme(scheme)
73+
fakeClient := fake.NewClientBuilder().
74+
WithScheme(scheme).
75+
Build()
76+
if err := ds.PoolSet(ctx, fakeClient, pool); err != nil {
77+
t.Error(err, "Error while setting inference pool")
78+
}
79+
ds.PodUpdateOrAddIfNotExist(pod)
80+
81+
tests := []struct {
82+
name string
83+
reqBodyMap map[string]interface{}
84+
wantErrCode string
85+
wantReqCtx *RequestContext
86+
wantRespBody map[string]interface{}
87+
}{
88+
{
89+
name: "successful request",
90+
reqBodyMap: map[string]interface{}{
91+
"model": tsModel,
92+
"prompt": "test prompt",
93+
},
94+
wantReqCtx: &RequestContext{
95+
Model: tsModel,
96+
ResolvedTargetModel: tsModel,
97+
TargetPod: "/pod1",
98+
TargetEndpoint: "address-1:8000",
99+
},
100+
wantRespBody: map[string]interface{}{
101+
"model": tsModel,
102+
"prompt": "test prompt",
103+
},
104+
},
105+
{
106+
name: "successful request with target model",
107+
reqBodyMap: map[string]interface{}{
108+
"model": modelWithTarget,
109+
"prompt": "test prompt",
110+
},
111+
wantReqCtx: &RequestContext{
112+
Model: modelWithTarget,
113+
ResolvedTargetModel: modelWithTarget,
114+
TargetPod: "/pod1",
115+
TargetEndpoint: "address-1:8000",
116+
},
117+
wantRespBody: map[string]interface{}{
118+
"model": modelWithTarget,
119+
"prompt": "test prompt",
120+
},
121+
},
122+
{
123+
name: "no model defined, expect err",
124+
wantErrCode: errutil.BadRequest,
125+
},
126+
{
127+
name: "invalid model defined, expect err",
128+
reqBodyMap: map[string]interface{}{
129+
"model": "non-existent-model",
130+
"prompt": "test prompt",
131+
},
132+
wantErrCode: errutil.BadConfiguration,
133+
},
134+
{
135+
name: "invalid target defined, expect err",
136+
reqBodyMap: map[string]interface{}{
137+
"model": "food-review-1",
138+
"prompt": "test prompt",
139+
},
140+
wantErrCode: errutil.BadConfiguration,
141+
},
142+
}
143+
144+
for _, test := range tests {
145+
t.Run(test.name, func(t *testing.T) {
146+
server := NewStreamingServer(scheduling.NewScheduler(ds), DefaultDestinationEndpointHintMetadataNamespace, DefaultDestinationEndpointHintKey, ds)
147+
reqCtx := &RequestContext{
148+
Request: &Request{
149+
Body: test.reqBodyMap,
150+
},
151+
}
152+
reqCtx, err := server.HandleRequestBody(ctx, reqCtx)
153+
154+
if test.wantErrCode != "" {
155+
if err == nil {
156+
t.Fatalf("HandleRequestBody should have returned an error containing '%s', but got nil", test.wantErrCode)
157+
}
158+
if !strings.Contains(err.Error(), test.wantErrCode) {
159+
t.Fatalf("HandleRequestBody returned error '%v', which does not contain expected substring '%s'", err, test.wantErrCode)
160+
}
161+
return
162+
}
163+
164+
if err != nil {
165+
t.Fatalf("HandleRequestBody returned unexpected error: %v", err)
166+
}
167+
168+
if test.wantReqCtx != nil {
169+
if diff := cmp.Diff(test.wantReqCtx.Model, reqCtx.Model); diff != "" {
170+
t.Errorf("HandleRequestBody returned unexpected reqCtx.Model, diff(-want, +got): %v", diff)
171+
}
172+
if diff := cmp.Diff(test.wantReqCtx.ResolvedTargetModel, reqCtx.ResolvedTargetModel); diff != "" {
173+
t.Errorf("HandleRequestBody returned unexpected reqCtx.ResolvedTargetModel, diff(-want, +got): %v", diff)
174+
}
175+
if diff := cmp.Diff(test.wantReqCtx.TargetPod, reqCtx.TargetPod); diff != "" {
176+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetPod, diff(-want, +got): %v", diff)
177+
}
178+
if diff := cmp.Diff(test.wantReqCtx.TargetEndpoint, reqCtx.TargetEndpoint); diff != "" {
179+
t.Errorf("HandleRequestBody returned unexpected reqCtx.TargetEndpoint, diff(-want, +got): %v", diff)
180+
}
181+
}
182+
})
183+
}
184+
}

0 commit comments

Comments
 (0)