diff --git a/pkg/body-based-routing/server/runserver.go b/pkg/body-based-routing/server/runserver.go index 3674c6cf6..55e794222 100644 --- a/pkg/body-based-routing/server/runserver.go +++ b/pkg/body-based-routing/server/runserver.go @@ -32,7 +32,8 @@ import ( // ExtProcServerRunner provides methods to manage an external process server. type ExtProcServerRunner struct { - GrpcPort int + GrpcPort int + SecureServing bool } // Default values for CLI flags in main @@ -42,7 +43,8 @@ const ( func NewDefaultExtProcServerRunner() *ExtProcServerRunner { return &ExtProcServerRunner{ - GrpcPort: DefaultGrpcPort, + GrpcPort: DefaultGrpcPort, + SecureServing: true, } } @@ -50,18 +52,20 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { // The runnable implements LeaderElectionRunnable with leader election disabled. func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { return runnable.NoLeaderElection(manager.RunnableFunc(func(ctx context.Context) error { - cert, err := tlsutil.CreateSelfSignedTLSCertificate(logger) - if err != nil { - logger.Error(err, "Failed to create self signed certificate") - return err + var srv *grpc.Server + if r.SecureServing { + cert, err := tlsutil.CreateSelfSignedTLSCertificate(logger) + if err != nil { + logger.Error(err, "Failed to create self signed certificate") + return err + } + creds := credentials.NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}) + srv = grpc.NewServer(grpc.Creds(creds)) + } else { + srv = grpc.NewServer() } - creds := credentials.NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}) - srv := grpc.NewServer(grpc.Creds(creds)) - extProcPb.RegisterExternalProcessorServer( - srv, - handlers.NewServer(), - ) + extProcPb.RegisterExternalProcessorServer(srv, handlers.NewServer()) // Forward to the gRPC runnable. return runnable.GRPCServer("ext-proc", srv, r.GrpcPort).Start(ctx) diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go new file mode 100644 index 000000000..be8b27213 --- /dev/null +++ b/test/integration/bbr/hermetic_test.go @@ -0,0 +1,173 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package bbr contains integration tests for the body-based routing extension. +package bbr + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/testing/protocmp" + runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/server" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const port = runserver.DefaultGrpcPort + +var logger = logutil.NewTestLogger().V(logutil.VERBOSE) + +func TestBodyBasedRouting(t *testing.T) { + tests := []struct { + name string + req *extProcPb.ProcessingRequest + wantHeaders []*configPb.HeaderValueOption + wantErr bool + }{ + { + name: "success adding model parameter to header", + req: generateRequest(logger, "llama"), + wantHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "X-Gateway-Model-Name", + RawValue: []byte("llama"), + }, + }, + }, + wantErr: false, + }, + { + name: "no model parameter", + req: generateRequest(logger, ""), + wantHeaders: []*configPb.HeaderValueOption{}, + wantErr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, cleanup := setUpHermeticServer() + t.Cleanup(cleanup) + + want := &extProcPb.ProcessingResponse{} + if len(test.wantHeaders) > 0 { + want.Response = &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: test.wantHeaders, + }, + ClearRouteCache: true, + }, + }, + } + } else { + want.Response = &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{}, + } + } + + res, err := sendRequest(t, client, test.req) + if err != nil && !test.wantErr { + t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) + } + if diff := cmp.Diff(want, res, protocmp.Transform()); diff != "" { + t.Errorf("Unexpected response, (-want +got): %v", diff) + } + }) + } +} + +func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { + serverCtx, stopServer := context.WithCancel(context.Background()) + serverRunner := runserver.NewDefaultExtProcServerRunner() + serverRunner.SecureServing = false + + go func() { + if err := serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil { + logutil.Fatal(logger, err, "Failed to start ext-proc server") + } + }() + + address := fmt.Sprintf("localhost:%v", port) + // Create a grpc connection + conn, err := grpc.NewClient(address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + logutil.Fatal(logger, err, "Failed to connect", "address", address) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + client, err = extProcPb.NewExternalProcessorClient(conn).Process(ctx) + if err != nil { + logutil.Fatal(logger, err, "Failed to create client") + } + return client, func() { + cancel() + conn.Close() + stopServer() + + // wait a little until the goroutines actually exit + time.Sleep(5 * time.Second) + } +} + +func generateRequest(logger logr.Logger, model string) *extProcPb.ProcessingRequest { + j := map[string]interface{}{ + "prompt": "test1", + "max_tokens": 100, + "temperature": 0, + } + if model != "" { + j["model"] = model + } + + llmReq, err := json.Marshal(j) + if err != nil { + logutil.Fatal(logger, err, "Failed to unmarshal LLM request") + } + req := &extProcPb.ProcessingRequest{ + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: llmReq}, + }, + } + return req +} + +func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { + t.Logf("Sending request: %v", req) + if err := client.Send(req); err != nil { + t.Logf("Failed to send request %+v: %v", req, err) + return nil, err + } + + res, err := client.Recv() + if err != nil { + t.Logf("Failed to receive: %v", err) + return nil, err + } + t.Logf("Received request %+v", res) + return res, err +}