@@ -19,20 +19,19 @@ package bbr
19
19
20
20
import (
21
21
"context"
22
- "encoding/json"
23
22
"fmt"
24
23
"testing"
25
24
"time"
26
25
27
26
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
28
27
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
29
- "github.com/go-logr/logr"
30
28
"github.com/google/go-cmp/cmp"
31
29
"google.golang.org/grpc"
32
30
"google.golang.org/grpc/credentials/insecure"
33
31
"google.golang.org/protobuf/testing/protocmp"
34
32
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/server"
35
33
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
34
+ integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration"
36
35
)
37
36
38
37
var logger = logutil .NewTestLogger ().V (logutil .VERBOSE )
@@ -46,7 +45,7 @@ func TestBodyBasedRouting(t *testing.T) {
46
45
}{
47
46
{
48
47
name : "success adding model parameter to header" ,
49
- req : generateRequest (logger , "llama" ),
48
+ req : integrationutils . GenerateRequest (logger , "test" , "llama" ),
50
49
wantHeaders : []* configPb.HeaderValueOption {
51
50
{
52
51
Header : & configPb.HeaderValue {
@@ -59,15 +58,15 @@ func TestBodyBasedRouting(t *testing.T) {
59
58
},
60
59
{
61
60
name : "no model parameter" ,
62
- req : generateRequest (logger , "" ),
61
+ req : integrationutils . GenerateRequest (logger , "test1" , "" ),
63
62
wantHeaders : []* configPb.HeaderValueOption {},
64
63
wantErr : false ,
65
64
},
66
65
}
67
66
68
67
for _ , test := range tests {
69
68
t .Run (test .name , func (t * testing.T ) {
70
- client , cleanup := setUpHermeticServer ()
69
+ client , cleanup := setUpHermeticServer (false )
71
70
t .Cleanup (cleanup )
72
71
73
72
want := & extProcPb.ProcessingResponse {}
@@ -88,7 +87,7 @@ func TestBodyBasedRouting(t *testing.T) {
88
87
}
89
88
}
90
89
91
- res , err := sendRequest (t , client , test .req )
90
+ res , err := integrationutils . SendRequest (t , client , test .req )
92
91
if err != nil && ! test .wantErr {
93
92
t .Errorf ("Unexpected error, got: %v, want error: %v" , err , test .wantErr )
94
93
}
@@ -99,12 +98,171 @@ func TestBodyBasedRouting(t *testing.T) {
99
98
}
100
99
}
101
100
102
- func setUpHermeticServer () (client extProcPb.ExternalProcessor_ProcessClient , cleanup func ()) {
101
+ func TestFullDuplexStreamed_BodyBasedRouting (t * testing.T ) {
102
+ tests := []struct {
103
+ name string
104
+ reqs []* extProcPb.ProcessingRequest
105
+ wantResponses []* extProcPb.ProcessingResponse
106
+ wantErr bool
107
+ }{
108
+ {
109
+ name : "success adding model parameter to header" ,
110
+ reqs : integrationutils .GenerateStreamedRequestSet (logger , "test" , "foo" ),
111
+ wantResponses : []* extProcPb.ProcessingResponse {
112
+ {
113
+ Response : & extProcPb.ProcessingResponse_RequestHeaders {
114
+ RequestHeaders : & extProcPb.HeadersResponse {
115
+ Response : & extProcPb.CommonResponse {
116
+ ClearRouteCache : true ,
117
+ HeaderMutation : & extProcPb.HeaderMutation {
118
+ SetHeaders : []* configPb.HeaderValueOption {
119
+ {
120
+ Header : & configPb.HeaderValue {
121
+ Key : "X-Gateway-Model-Name" ,
122
+ RawValue : []byte ("foo" ),
123
+ },
124
+ },
125
+ }},
126
+ },
127
+ },
128
+ },
129
+ },
130
+ {
131
+ Response : & extProcPb.ProcessingResponse_RequestBody {
132
+ RequestBody : & extProcPb.BodyResponse {
133
+ Response : & extProcPb.CommonResponse {
134
+ BodyMutation : & extProcPb.BodyMutation {
135
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
136
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
137
+ Body : []byte ("{\" max_tokens\" :100,\" model\" :\" foo\" ,\" prompt\" :\" test\" ,\" temperature\" :0}" ),
138
+ EndOfStream : true ,
139
+ },
140
+ },
141
+ },
142
+ },
143
+ },
144
+ },
145
+ },
146
+ },
147
+ },
148
+ {
149
+ name : "success adding model parameter to header with multiple body chunks" ,
150
+ reqs : []* extProcPb.ProcessingRequest {
151
+ {
152
+ Request : & extProcPb.ProcessingRequest_RequestHeaders {
153
+ RequestHeaders : & extProcPb.HttpHeaders {
154
+ Headers : & configPb.HeaderMap {
155
+ Headers : []* configPb.HeaderValue {
156
+ {
157
+ Key : "hi" ,
158
+ Value : "mom" ,
159
+ },
160
+ },
161
+ },
162
+ },
163
+ },
164
+ },
165
+ {
166
+ Request : & extProcPb.ProcessingRequest_RequestBody {
167
+ RequestBody : & extProcPb.HttpBody {Body : []byte ("{\" max_tokens\" :100,\" model\" :\" sql-lo" ), EndOfStream : false },
168
+ },
169
+ },
170
+ {
171
+ Request : & extProcPb.ProcessingRequest_RequestBody {
172
+ RequestBody : & extProcPb.HttpBody {Body : []byte ("ra-sheddable\" ,\" prompt\" :\" test\" ,\" temperature\" :0}" ), EndOfStream : true },
173
+ },
174
+ },
175
+ },
176
+ wantResponses : []* extProcPb.ProcessingResponse {
177
+ {
178
+ Response : & extProcPb.ProcessingResponse_RequestHeaders {
179
+ RequestHeaders : & extProcPb.HeadersResponse {
180
+ Response : & extProcPb.CommonResponse {
181
+ ClearRouteCache : true ,
182
+ HeaderMutation : & extProcPb.HeaderMutation {
183
+ SetHeaders : []* configPb.HeaderValueOption {
184
+ {
185
+ Header : & configPb.HeaderValue {
186
+ Key : "X-Gateway-Model-Name" ,
187
+ RawValue : []byte ("sql-lora-sheddable" ),
188
+ },
189
+ },
190
+ }},
191
+ },
192
+ },
193
+ },
194
+ },
195
+ {
196
+ Response : & extProcPb.ProcessingResponse_RequestBody {
197
+ RequestBody : & extProcPb.BodyResponse {
198
+ Response : & extProcPb.CommonResponse {
199
+ BodyMutation : & extProcPb.BodyMutation {
200
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
201
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
202
+ Body : []byte ("{\" max_tokens\" :100,\" model\" :\" sql-lora-sheddable\" ,\" prompt\" :\" test\" ,\" temperature\" :0}" ),
203
+ EndOfStream : true ,
204
+ },
205
+ },
206
+ },
207
+ },
208
+ },
209
+ },
210
+ },
211
+ },
212
+ },
213
+ {
214
+ name : "no model parameter" ,
215
+ reqs : integrationutils .GenerateStreamedRequestSet (logger , "test" , "" ),
216
+ wantResponses : []* extProcPb.ProcessingResponse {
217
+ {
218
+ Response : & extProcPb.ProcessingResponse_RequestHeaders {
219
+ RequestHeaders : & extProcPb.HeadersResponse {},
220
+ },
221
+ },
222
+ {
223
+ Response : & extProcPb.ProcessingResponse_RequestBody {
224
+ RequestBody : & extProcPb.BodyResponse {
225
+ Response : & extProcPb.CommonResponse {
226
+ BodyMutation : & extProcPb.BodyMutation {
227
+ Mutation : & extProcPb.BodyMutation_StreamedResponse {
228
+ StreamedResponse : & extProcPb.StreamedBodyResponse {
229
+ Body : []byte ("{\" max_tokens\" :100,\" prompt\" :\" test\" ,\" temperature\" :0}" ),
230
+ EndOfStream : true ,
231
+ },
232
+ },
233
+ },
234
+ },
235
+ },
236
+ },
237
+ },
238
+ },
239
+ },
240
+ }
241
+
242
+ for _ , test := range tests {
243
+ t .Run (test .name , func (t * testing.T ) {
244
+ client , cleanup := setUpHermeticServer (true )
245
+ t .Cleanup (cleanup )
246
+
247
+ responses , err := integrationutils .StreamedRequest (t , client , test .reqs , len (test .wantResponses ))
248
+ if err != nil && ! test .wantErr {
249
+ t .Errorf ("Unexpected error, got: %v, want error: %v" , err , test .wantErr )
250
+ }
251
+
252
+ if diff := cmp .Diff (test .wantResponses , responses , protocmp .Transform ()); diff != "" {
253
+ t .Errorf ("Unexpected response, (-want +got): %v" , diff )
254
+ }
255
+ })
256
+ }
257
+ }
258
+
259
+ func setUpHermeticServer (streaming bool ) (client extProcPb.ExternalProcessor_ProcessClient , cleanup func ()) {
103
260
port := 9004
104
261
105
262
serverCtx , stopServer := context .WithCancel (context .Background ())
106
263
serverRunner := runserver .NewDefaultExtProcServerRunner (port , false )
107
264
serverRunner .SecureServing = false
265
+ serverRunner .Streaming = streaming
108
266
109
267
go func () {
110
268
if err := serverRunner .AsRunnable (logger .WithName ("ext-proc" )).Start (serverCtx ); err != nil {
@@ -133,41 +291,3 @@ func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cl
133
291
time .Sleep (5 * time .Second )
134
292
}
135
293
}
136
-
137
- func generateRequest (logger logr.Logger , model string ) * extProcPb.ProcessingRequest {
138
- j := map [string ]interface {}{
139
- "prompt" : "test1" ,
140
- "max_tokens" : 100 ,
141
- "temperature" : 0 ,
142
- }
143
- if model != "" {
144
- j ["model" ] = model
145
- }
146
-
147
- llmReq , err := json .Marshal (j )
148
- if err != nil {
149
- logutil .Fatal (logger , err , "Failed to unmarshal LLM request" )
150
- }
151
- req := & extProcPb.ProcessingRequest {
152
- Request : & extProcPb.ProcessingRequest_RequestBody {
153
- RequestBody : & extProcPb.HttpBody {Body : llmReq },
154
- },
155
- }
156
- return req
157
- }
158
-
159
- func sendRequest (t * testing.T , client extProcPb.ExternalProcessor_ProcessClient , req * extProcPb.ProcessingRequest ) (* extProcPb.ProcessingResponse , error ) {
160
- t .Logf ("Sending request: %v" , req )
161
- if err := client .Send (req ); err != nil {
162
- t .Logf ("Failed to send request %+v: %v" , req , err )
163
- return nil , err
164
- }
165
-
166
- res , err := client .Recv ()
167
- if err != nil {
168
- t .Logf ("Failed to receive: %v" , err )
169
- return nil , err
170
- }
171
- t .Logf ("Received request %+v" , res )
172
- return res , err
173
- }
0 commit comments