diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go
index af31da429..8ada3e64d 100644
--- a/pkg/epp/datastore/datastore.go
+++ b/pkg/epp/datastore/datastore.go
@@ -20,10 +20,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"math/rand"
 	"sync"
 
-	"github.com/go-logr/logr"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/labels"
 	"k8s.io/apimachinery/pkg/types"
@@ -304,35 +302,6 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha2.LabelKey]v1alpha2.LabelV
 	return outMap
 }
 
-func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
-	source := rand.NewSource(rand.Int63())
-	if seed > 0 {
-		source = rand.NewSource(seed)
-	}
-	r := rand.New(source)
-
-	// all the weight values are nil, then we should return random model name
-	if model.Spec.TargetModels[0].Weight == nil {
-		index := r.Int31n(int32(len(model.Spec.TargetModels)))
-		return model.Spec.TargetModels[index].Name
-	}
-
-	var weights int32
-	for _, model := range model.Spec.TargetModels {
-		weights += *model.Weight
-	}
-	logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
-	randomVal := r.Int31n(weights)
-	// TODO: optimize this without using loop
-	for _, model := range model.Spec.TargetModels {
-		if randomVal < *model.Weight {
-			return model.Name
-		}
-		randomVal -= *model.Weight
-	}
-	return ""
-}
-
 func IsCritical(model *v1alpha2.InferenceModel) bool {
 	if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha2.Critical {
 		return true
diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go
index f60a4cc9b..1a88e5dc0 100644
--- a/pkg/epp/datastore/datastore_test.go
+++ b/pkg/epp/datastore/datastore_test.go
@@ -30,7 +30,6 @@ import (
 	"k8s.io/apimachinery/pkg/types"
 	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
 	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
-	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 	testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
 )
 
@@ -223,113 +222,6 @@ func TestModel(t *testing.T) {
 	}
 }
 
-func TestRandomWeightedDraw(t *testing.T) {
-	logger := logutil.NewTestLogger()
-	tests := []struct {
-		name  string
-		model *v1alpha2.InferenceModel
-		want  string
-	}{
-		{
-			name: "'random' distribution",
-			model: &v1alpha2.InferenceModel{
-				Spec: v1alpha2.InferenceModelSpec{
-					TargetModels: []v1alpha2.TargetModel{
-						{
-							Name:   "canary",
-							Weight: pointer(50),
-						},
-						{
-							Name:   "v1",
-							Weight: pointer(50),
-						},
-					},
-				},
-			},
-			want: "canary",
-		},
-		{
-			name: "'random' distribution",
-			model: &v1alpha2.InferenceModel{
-				Spec: v1alpha2.InferenceModelSpec{
-					TargetModels: []v1alpha2.TargetModel{
-						{
-							Name:   "canary",
-							Weight: pointer(25),
-						},
-						{
-							Name:   "v1.1",
-							Weight: pointer(55),
-						},
-						{
-							Name:   "v1",
-							Weight: pointer(50),
-						},
-					},
-				},
-			},
-			want: "v1",
-		},
-		{
-			name: "'random' distribution",
-			model: &v1alpha2.InferenceModel{
-				Spec: v1alpha2.InferenceModelSpec{
-					TargetModels: []v1alpha2.TargetModel{
-						{
-							Name:   "canary",
-							Weight: pointer(20),
-						},
-						{
-							Name:   "v1.1",
-							Weight: pointer(20),
-						},
-						{
-							Name:   "v1",
-							Weight: pointer(10),
-						},
-					},
-				},
-			},
-			want: "v1.1",
-		},
-		{
-			name: "weighted distribution with weight unset",
-			model: &v1alpha2.InferenceModel{
-				Spec: v1alpha2.InferenceModelSpec{
-					TargetModels: []v1alpha2.TargetModel{
-						{
-							Name: "canary",
-						},
-						{
-							Name: "v1.1",
-						},
-						{
-							Name: "v1",
-						},
-					},
-				},
-			},
-			want: "canary",
-		},
-	}
-	var seedVal int64 = 420
-	for _, test := range tests {
-		t.Run(test.name, func(t *testing.T) {
-			for range 10000 {
-				model := RandomWeightedDraw(logger, test.model, seedVal)
-				if model != test.want {
-					t.Errorf("Model returned: %v != %v", model, test.want)
-					break
-				}
-			}
-		})
-	}
-}
-
-func pointer(v int32) *int32 {
-	return &v
-}
-
 var (
 	pod1 = &corev1.Pod{
 		ObjectMeta: metav1.ObjectMeta{
diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go
index 12afe4d74..d7678fadf 100644
--- a/pkg/epp/handlers/request.go
+++ b/pkg/epp/handlers/request.go
@@ -69,7 +69,7 @@ func (s *Server) HandleRequestBody(
 		return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
 	}
 	if len(modelObj.Spec.TargetModels) > 0 {
-		modelName = datastore.RandomWeightedDraw(logger, modelObj, 0)
+		modelName = RandomWeightedDraw(logger, modelObj, 0)
 		if modelName == "" {
 			return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
 		}
diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go
index 1452fdd2c..79ad7a6a0 100644
--- a/pkg/epp/handlers/response.go
+++ b/pkg/epp/handlers/response.go
@@ -85,9 +85,7 @@ func (s *Server) HandleResponseHeaders(
 		if header.Key == "content-type" {
 			contentType := header.RawValue
 			if strings.Contains(string(contentType), "text/event-stream") {
-				reqCtx.Streaming = true
-			} else {
-				reqCtx.Streaming = false
+				reqCtx.modelServerStreaming = true
 			}
 			typeFound = true
 		}
@@ -155,7 +153,7 @@ func (s *Server) HandleResponseBody(
 	loggerVerbose := logger.V(logutil.VERBOSE)
 	body := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
 
-	if reqCtx.Streaming {
+	if reqCtx.modelServerStreaming {
 		logger.V(logutil.DEBUG).Info("Processing HandleResponseBody")
 		if err := s.HandleStreaming(ctx, reqCtx, body, loggerVerbose); err != nil {
 			return nil, err
@@ -189,7 +187,7 @@ func (s *Server) HandleNonStreaming(
 	if err := json.Unmarshal(body.ResponseBody.Body, &res); err != nil {
 		return errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("unmarshaling response body: %v", err)}
 	}
-	reqCtx.Response = res
+	reqCtx.Usage = res.Usage
 	reqCtx.ResponseSize = len(body.ResponseBody.Body)
 	reqCtx.ResponseComplete = true
 	loggerVerbose.Info("Response generated", "response", res)
@@ -205,7 +203,7 @@ func (s *Server) HandleStreaming(
 	responseText := string(body.ResponseBody.Body)
 	if strings.Contains(responseText, streamingEndMsg) {
 		parsedResp := ParseRespForUsage(ctx, responseText, loggerVerbose)
-		reqCtx.Response = parsedResp
+		reqCtx.Usage = parsedResp.Usage
 	}
 
 	if body.ResponseBody.EndOfStream {
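Note on the `Usage` change above: the streaming path only extracts usage once a chunk containing the `data: [DONE]` sentinel arrives, and delegates the parsing to `ParseRespForUsage`. A self-contained sketch of that extraction pattern follows; the `Usage` and `response` shapes here are assumptions mirroring this package, not the actual implementation:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Usage mirrors the OpenAI-style usage block (assumed shape).
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	TotalTokens      int `json:"total_tokens"`
	CompletionTokens int `json:"completion_tokens"`
}

type response struct {
	Usage Usage `json:"usage"`
}

// extractUsage scans an SSE body chunk for the last data line that
// parses as JSON and returns its usage block. The chunk is only
// inspected when it contains the [DONE] sentinel, matching the guard
// in HandleStreaming above.
func extractUsage(chunk string) (Usage, bool) {
	if !strings.Contains(chunk, "data: [DONE]") {
		return Usage{}, false
	}
	var u Usage
	found := false
	for _, line := range strings.Split(chunk, "\n") {
		payload, ok := strings.CutPrefix(strings.TrimSpace(line), "data: ")
		if !ok || payload == "[DONE]" {
			continue
		}
		var r response
		if err := json.Unmarshal([]byte(payload), &r); err == nil {
			u, found = r.Usage, true
		}
	}
	return u, found
}

func main() {
	body := `data: {"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
data: [DONE]`
	if u, ok := extractUsage(body); ok {
		fmt.Printf("prompt=%d completion=%d total=%d\n", u.PromptTokens, u.CompletionTokens, u.TotalTokens)
	}
}
```

Scanning for the last parseable `data:` line matters because the usage chunk and the `[DONE]` sentinel can arrive in the same body buffer, as the hermetic test fixtures below show.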
diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go
index 8b6f16a7a..edfa3edba 100644
--- a/pkg/epp/handlers/response_test.go
+++ b/pkg/epp/handlers/response_test.go
@@ -65,7 +65,7 @@ func TestHandleResponseBody(t *testing.T) {
 		name    string
 		req     *extProcPb.ProcessingRequest_ResponseBody
 		reqCtx  *RequestContext
-		want    Response
+		want    Usage
 		wantErr bool
 	}{
 		{
@@ -75,12 +75,10 @@ func TestHandleResponseBody(t *testing.T) {
 					Body: []byte(body),
 				},
 			},
-			want: Response{
-				Usage: Usage{
-					PromptTokens:     11,
-					TotalTokens:      111,
-					CompletionTokens: 100,
-				},
+			want: Usage{
+				PromptTokens:     11,
+				TotalTokens:      111,
+				CompletionTokens: 100,
 			},
 		},
 		{
@@ -100,7 +98,7 @@ func TestHandleResponseBody(t *testing.T) {
 				},
 			},
 			reqCtx: &RequestContext{
-				Streaming: true,
+				modelServerStreaming: true,
 			},
 			wantErr: false, // In the middle of streaming response, so request context response is not set yet.
@@ -113,15 +111,13 @@ func TestHandleResponseBody(t *testing.T) {
 				},
 			},
 			reqCtx: &RequestContext{
-				Streaming: true,
+				modelServerStreaming: true,
 			},
 			wantErr: false,
-			want: Response{
-				Usage: Usage{
-					PromptTokens:     7,
-					TotalTokens:      17,
-					CompletionTokens: 10,
-				},
+			want: Usage{
+				PromptTokens:     7,
+				TotalTokens:      17,
+				CompletionTokens: 10,
 			},
 		},
 	}
@@ -141,7 +137,7 @@ func TestHandleResponseBody(t *testing.T) {
 				return
 			}
 
-			if diff := cmp.Diff(test.want, reqCtx.Response); diff != "" {
+			if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" {
 				t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff)
 			}
 		})
diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index 4f45ae82b..cd354c2f5 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -128,10 +128,10 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
 				reqCtx.ResponseCompleteTimestamp = time.Now()
 				metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
 				metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
-				metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.PromptTokens)
-				metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Response.Usage.CompletionTokens)
+				metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
+				metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
 			}
-			if reqCtx.Streaming {
+			if reqCtx.modelServerStreaming {
 				logger.V(logutil.DEBUG).Info("Request context after HandleResponseBody", "context", reqCtx)
 			} else {
 				loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
@@ -149,7 +149,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
 			}
 		}
 
-		if !reqCtx.Streaming {
+		if !reqCtx.modelServerStreaming {
 			loggerVerbose.Info("Response generated", "response", resp)
 		} else {
 			logger.V(logutil.DEBUG).Info("Response generated", "response", resp)
@@ -224,9 +224,32 @@ type RequestContext struct {
 	RequestReceivedTimestamp  time.Time
 	ResponseCompleteTimestamp time.Time
 	RequestSize               int
-	Response                  Response
+	Usage                     Usage
 	ResponseSize              int
 	ResponseComplete          bool
 	ResponseStatusCode        string
-	Streaming                 bool
+
+	RequestState         StreamRequestState
+	modelServerStreaming bool
+
+	reqHeaderResp  *extProcPb.ProcessingResponse
+	reqBodyResp    *extProcPb.ProcessingResponse
+	reqTrailerResp *extProcPb.ProcessingResponse
+
+	respHeaderResp  *extProcPb.ProcessingResponse
+	respBodyResp    *extProcPb.ProcessingResponse
+	respTrailerResp *extProcPb.ProcessingResponse
 }
+
+type StreamRequestState int
+
+const (
+	RequestReceived                  StreamRequestState = 0
+	HeaderRequestResponseComplete    StreamRequestState = 1
+	BodyRequestResponsesComplete     StreamRequestState = 2
+	TrailerRequestResponsesComplete  StreamRequestState = 3
+	ResponseRecieved                 StreamRequestState = 4
+	HeaderResponseResponseComplete   StreamRequestState = 5
+	BodyResponseResponsesComplete    StreamRequestState = 6
+	TrailerResponseResponsesComplete StreamRequestState = 7
+)
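The `RequestState` field and the `StreamRequestState` constants move into the shared `RequestContext` so both server implementations can use one context type; `updateStateAndSendIfNeeded` (in the next file) consumes them to flush buffered responses in the strict Header->Body->Trailer order that FULL_DUPLEX_STREAMING requires. A stripped-down sketch of that state-gated flushing, with a plain callback standing in for the gRPC stream (the names here are illustrative, not from the patch):

```go
package main

import "fmt"

type streamState int

const (
	requestReceived streamState = iota
	headerComplete
	bodyComplete
)

// ctx holds responses that may become ready out of order; flush sends
// whatever is ready, but only in header-then-body order.
type ctx struct {
	state      streamState
	headerResp string
	bodyResp   string
}

func (c *ctx) flush(send func(string) error) error {
	// No switch: both responses may be flushed in a single pass.
	if c.state == requestReceived && c.headerResp != "" {
		if err := send(c.headerResp); err != nil {
			return err
		}
		c.state = headerComplete
	}
	if c.state == headerComplete && c.bodyResp != "" {
		if err := send(c.bodyResp); err != nil {
			return err
		}
		c.state = bodyComplete
	}
	return nil
}

func main() {
	c := &ctx{state: requestReceived}
	c.bodyResp = "body"     // the body response became ready first...
	c.headerResp = "header" // ...but the header must still go out first.
	_ = c.flush(func(msg string) error {
		fmt.Println("sent:", msg)
		return nil
	})
}
```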
diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go
index 684a7542c..64f9c03be 100644
--- a/pkg/epp/handlers/streamingserver.go
+++ b/pkg/epp/handlers/streamingserver.go
@@ -1,3 +1,19 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
 package handlers
 
 import (
@@ -5,6 +21,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math/rand"
 	"strconv"
 	"strings"
 	"time"
@@ -16,6 +33,8 @@ import (
 	"google.golang.org/grpc/status"
 	"google.golang.org/protobuf/types/known/structpb"
 	"sigs.k8s.io/controller-runtime/pkg/log"
+	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
@@ -51,13 +70,13 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 
 	// Create request context to share states during life time of an HTTP request.
 	// See https://github.com/envoyproxy/envoy/issues/17540.
-	reqCtx := &StreamingRequestContext{
+	reqCtx := &RequestContext{
 		RequestState: RequestReceived,
 	}
 
 	var body []byte
-	var requestBody, responseBody map[string]interface{}
+
 	// Create error handling var as each request should only report once for
 	// error metrics. This doesn't cover the error "Cannot receive stream request" because
 	// such errors might happen even though response is processed.
@@ -90,8 +109,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
 
 		switch v := req.Request.(type) {
 		case *extProcPb.ProcessingRequest_RequestHeaders:
-			reqCtx.RequestReceivedTimestamp = time.Now()
-			// Do nothing. Header info is handled in the HandleRequestBody func
+			err = s.HandleRequestHeaders(ctx, reqCtx, v)
 		case *extProcPb.ProcessingRequest_RequestBody:
 			loggerVerbose.Info("Incoming body chunk", "body", string(v.RequestBody.Body), "EoS", v.RequestBody.EndOfStream)
 			// In the stream case, we can receive multiple request bodies.
@@ -237,7 +255,7 @@
 
 // updateStateAndSendIfNeeded checks state and can send multiple responses in a single pass, but only if ordered properly.
 // Order of requests matters in FULL_DUPLEX_STREAMING. For both request and response, the order of response sent back MUST be: Header->Body->Trailer, with trailer being optional.
-func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProcessor_ProcessServer, loggerVerbose logr.Logger) error {
+func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProcessor_ProcessServer, loggerVerbose logr.Logger) error {
 	// No switch statement as we could send multiple responses in one pass.
 	if r.RequestState == RequestReceived && r.reqHeaderResp != nil {
 		loggerVerbose.Info("Request header response", "obj", r.reqHeaderResp)
@@ -291,51 +309,13 @@ func (r *StreamingRequestContext) updateStateAndSendIfNeeded(srv extProcPb.Exter
 	return nil
 }
 
-type StreamingRequestContext struct {
-	TargetPod                 string
-	TargetEndpoint            string
-	Model                     string
-	ResolvedTargetModel       string
-	RequestState              StreamRequestState
-	RequestReceivedTimestamp  time.Time
-	ResponseCompleteTimestamp time.Time
-	RequestSize               int
-	Usage                     Usage
-	ResponseSize              int
-	ResponseComplete          bool
-	ResponseStatusCode        string
-
-	modelServerStreaming bool
-
-	reqHeaderResp  *extProcPb.ProcessingResponse
-	reqBodyResp    *extProcPb.ProcessingResponse
-	reqTrailerResp *extProcPb.ProcessingResponse
-
-	respHeaderResp  *extProcPb.ProcessingResponse
-	respBodyResp    *extProcPb.ProcessingResponse
-	respTrailerResp *extProcPb.ProcessingResponse
-}
-
-type StreamRequestState int
-
-const (
-	RequestReceived                  StreamRequestState = 0
-	HeaderRequestResponseComplete    StreamRequestState = 1
-	BodyRequestResponsesComplete     StreamRequestState = 2
-	TrailerRequestResponsesComplete  StreamRequestState = 3
-	ResponseRecieved                 StreamRequestState = 4
-	HeaderResponseResponseComplete   StreamRequestState = 5
-	BodyResponseResponsesComplete    StreamRequestState = 6
-	TrailerResponseResponsesComplete StreamRequestState = 7
-)
-
 // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
 func (s *StreamingServer) HandleRequestBody(
 	ctx context.Context,
-	reqCtx *StreamingRequestContext,
+	reqCtx *RequestContext,
 	req *extProcPb.ProcessingRequest,
 	requestBodyMap map[string]interface{},
-) (*StreamingRequestContext, error) {
+) (*RequestContext, error) {
 	var requestBodyBytes []byte
 	logger := log.FromContext(ctx)
 	loggerVerbose := logger.V(logutil.VERBOSE)
@@ -357,7 +337,7 @@ func (s *StreamingServer) HandleRequestBody(
 		return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
 	}
 	if len(modelObj.Spec.TargetModels) > 0 {
-		modelName = datastore.RandomWeightedDraw(logger, modelObj, 0)
+		modelName = RandomWeightedDraw(logger, modelObj, 0)
 		if modelName == "" {
 			return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
 		}
@@ -405,63 +385,8 @@ func (s *StreamingServer) HandleRequestBody(
 	reqCtx.TargetPod = targetPod.NamespacedName.String()
 	reqCtx.TargetEndpoint = endpoint
 
-	headers := []*configPb.HeaderValueOption{
-		{
-			Header: &configPb.HeaderValue{
-				Key:      s.destinationEndpointHintKey,
-				RawValue: []byte(endpoint),
-			},
-		},
-		// We need to update the content length header if the body is mutated, see Envoy doc:
-		// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
-		{
-			Header: &configPb.HeaderValue{
-				Key:      "Content-Length",
-				RawValue: []byte(strconv.Itoa(len(requestBodyBytes))),
-			},
-		},
-	}
-	// Print headers for debugging
-	for _, header := range headers {
-		logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue)
-	}
+	s.populateRequestHeaderResponse(ctx, reqCtx, endpoint, len(requestBodyBytes))
 
-	targetEndpointValue := &structpb.Struct{
-		Fields: map[string]*structpb.Value{
-			s.destinationEndpointHintKey: {
-				Kind: &structpb.Value_StringValue{
-					StringValue: endpoint,
-				},
-			},
-		},
-	}
-	dynamicMetadata := targetEndpointValue
-	if s.destinationEndpointHintMetadataNamespace != "" {
-		// If a namespace is defined, wrap the selected endpoint with that.
-		dynamicMetadata = &structpb.Struct{
-			Fields: map[string]*structpb.Value{
-				s.destinationEndpointHintMetadataNamespace: {
-					Kind: &structpb.Value_StructValue{
-						StructValue: targetEndpointValue,
-					},
-				},
-			},
-		}
-	}
-
-	reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{
-		Response: &extProcPb.ProcessingResponse_RequestHeaders{
-			RequestHeaders: &extProcPb.HeadersResponse{
-				Response: &extProcPb.CommonResponse{
-					ClearRouteCache: true,
-					HeaderMutation: &extProcPb.HeaderMutation{
-						SetHeaders: headers,
-					},
-				},
-			},
-		},
-		DynamicMetadata: dynamicMetadata,
-	}
 	reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
 		// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
 		// and as an unstructured ext-proc response metadata key/value pair. This enables different integration
@@ -487,9 +412,9 @@ func (s *StreamingServer) HandleRequestBody(
 // HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling.
 func (s *StreamingServer) HandleResponseBody(
 	ctx context.Context,
-	reqCtx *StreamingRequestContext,
+	reqCtx *RequestContext,
 	response map[string]interface{},
-) (*StreamingRequestContext, error) {
+) (*RequestContext, error) {
 	logger := log.FromContext(ctx)
 	loggerVerbose := logger.V(logutil.VERBOSE)
 	loggerVerbose.Info("Processing HandleResponseBody")
@@ -541,7 +466,7 @@ func (s *StreamingServer) HandleResponseBody(
 // The function is to handle streaming response if the modelServer is streaming.
 func (s *StreamingServer) HandleResponseBodyModelStreaming(
 	ctx context.Context,
-	reqCtx *StreamingRequestContext,
+	reqCtx *RequestContext,
 	responseText string,
 ) {
 	logger := log.FromContext(ctx)
@@ -554,3 +479,124 @@ func (s *StreamingServer) HandleResponseBodyModelStreaming(
 		metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.CompletionTokens)
 	}
 }
+
+func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
+	reqCtx.RequestReceivedTimestamp = time.Now()
+
+	// An EoS in the request headers means this request has no body or trailers.
+	if req.RequestHeaders.EndOfStream {
+		// We will route this request to a random pod as this is assumed to just be a GET.
+		// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
+		// The above PR will address endpoint admission, but currently any request without a body will be
+		// routed to a random upstream pod.
+		pod := GetRandomPod(s.datastore)
+		pool, err := s.datastore.PoolGet()
+		if err != nil {
+			return err
+		}
+		endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
+		s.populateRequestHeaderResponse(ctx, reqCtx, endpoint, 0)
+	}
+	return nil
+}
+
+func (s *StreamingServer) populateRequestHeaderResponse(ctx context.Context, reqCtx *RequestContext, endpoint string, requestBodyLength int) {
+	logger := log.FromContext(ctx)
+	headers := []*configPb.HeaderValueOption{
+		{
+			Header: &configPb.HeaderValue{
+				Key:      s.destinationEndpointHintKey,
+				RawValue: []byte(endpoint),
+			},
+		},
+	}
+	if requestBodyLength > 0 {
+		// We need to update the content length header if the body is mutated, see Envoy doc:
+		// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
+		headers = append(headers, &configPb.HeaderValueOption{
+			Header: &configPb.HeaderValue{
+				Key:      "Content-Length",
+				RawValue: []byte(strconv.Itoa(requestBodyLength)),
+			},
+		})
+	}
+	// Print headers for debugging
+	for _, header := range headers {
+		logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue)
+	}
+
+	targetEndpointValue := &structpb.Struct{
+		Fields: map[string]*structpb.Value{
+			s.destinationEndpointHintKey: {
+				Kind: &structpb.Value_StringValue{
+					StringValue: endpoint,
+				},
+			},
+		},
+	}
+	dynamicMetadata := targetEndpointValue
+	if s.destinationEndpointHintMetadataNamespace != "" {
+		// If a namespace is defined, wrap the selected endpoint with that.
+		dynamicMetadata = &structpb.Struct{
+			Fields: map[string]*structpb.Value{
+				s.destinationEndpointHintMetadataNamespace: {
+					Kind: &structpb.Value_StructValue{
+						StructValue: targetEndpointValue,
+					},
+				},
+			},
+		}
+	}
+
+	reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{
+		Response: &extProcPb.ProcessingResponse_RequestHeaders{
+			RequestHeaders: &extProcPb.HeadersResponse{
+				Response: &extProcPb.CommonResponse{
+					ClearRouteCache: true,
+					HeaderMutation: &extProcPb.HeaderMutation{
+						SetHeaders: headers,
+					},
+				},
+			},
+		},
+		DynamicMetadata: dynamicMetadata,
+	}
+}
+
+func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
+	// TODO: after we are down to 1 server implementation, make these methods a part of the struct
+	// and handle random seeding on the struct.
+	source := rand.NewSource(rand.Int63())
+	if seed > 0 {
+		source = rand.NewSource(seed)
+	}
+	r := rand.New(source)
+
+	// If all the weight values are nil, then we should return a random model name.
+	if model.Spec.TargetModels[0].Weight == nil {
+		index := r.Int31n(int32(len(model.Spec.TargetModels)))
+		return model.Spec.TargetModels[index].Name
+	}
+
+	var weights int32
+	for _, model := range model.Spec.TargetModels {
+		weights += *model.Weight
+	}
+	logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
+	randomVal := r.Int31n(weights)
+	// TODO: optimize this without using a loop
+	for _, model := range model.Spec.TargetModels {
+		if randomVal < *model.Weight {
+			return model.Name
+		}
+		randomVal -= *model.Weight
+	}
+	return ""
+}
+
+func GetRandomPod(ds datastore.Datastore) *backendmetrics.Pod {
+	pods := ds.PodGetAll()
+	number := rand.Intn(len(pods))
+	pod := pods[number]
+	return pod.GetPod()
+}
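`RandomWeightedDraw` (relocated above from the datastore package) samples proportionally to weight: it draws a value in `[0, totalWeight)`, then subtracts each model's weight from the draw until the draw falls below the current model's weight. A minimal standalone sketch of the same cumulative-subtraction technique, with illustrative names:

```go
package main

import (
	"fmt"
	"math/rand"
)

type target struct {
	name   string
	weight int32
}

// weightedDraw picks a name with probability weight/totalWeight using
// the same subtract-as-you-scan approach as RandomWeightedDraw above.
func weightedDraw(r *rand.Rand, targets []target) string {
	var total int32
	for _, t := range targets {
		total += t.weight
	}
	v := r.Int31n(total)
	for _, t := range targets {
		if v < t.weight {
			return t.name
		}
		v -= t.weight
	}
	return "" // unreachable when all weights are positive
}

func main() {
	r := rand.New(rand.NewSource(420)) // fixed seed => reproducible draws
	counts := map[string]int{}
	targets := []target{{"canary", 25}, {"v1", 75}}
	for i := 0; i < 10000; i++ {
		counts[weightedDraw(r, targets)]++
	}
	fmt.Println(counts) // roughly 2500 canary / 7500 v1
}
```

With positive weights the scan always terminates inside the loop, so the empty-string return only signals inconsistent weight data, which the callers above treat as a BadConfiguration error.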
t.Errorf("Model returned: %v != %v", model, test.want) + break + } + } + }) + } +} + +func pointer(v int32) *int32 { + return &v +} diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index cb18eaa4b..b12925eda 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -427,10 +427,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1 - `}, + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. + # TYPE inference_model_request_total counter + inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1 + `}, wantErr: false, wantResponses: []*extProcPb.ProcessingResponse{ { @@ -508,10 +508,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 - `}, + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. + # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 + `}, wantErr: false, wantResponses: []*extProcPb.ProcessingResponse{ { @@ -589,10 +589,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 - `}, + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. + # TYPE inference_model_request_total counter + inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 + `}, wantErr: false, wantResponses: []*extProcPb.ProcessingResponse{ { @@ -716,10 +716,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 - `}, + # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. 
diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go
index cb18eaa4b..b12925eda 100644
--- a/test/integration/epp/hermetic_test.go
+++ b/test/integration/epp/hermetic_test.go
@@ -427,10 +427,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -508,10 +508,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -589,10 +589,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -716,10 +716,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -824,10 +824,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -932,10 +932,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				},
 			},
 			wantMetrics: map[string]string{`inference_model_request_total`: `
-	# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-	# TYPE inference_model_request_total counter
-	inference_model_request_total{model_name="direct-model",target_model_name="direct-model"} 1
-	`},
+			# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
+			# TYPE inference_model_request_total counter
+			inference_model_request_total{model_name="direct-model",target_model_name="direct-model"} 1
+			`},
 			wantErr: false,
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
@@ -1234,7 +1234,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 				Request: &extProcPb.ProcessingRequest_ResponseBody{
 					ResponseBody: &extProcPb.HttpBody{
 						Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
-data: [DONE]`,
+			data: [DONE]`,
 						),
 						EndOfStream: false},
 				},
@@ -1249,31 +1249,31 @@ data: [DONE]`,
 			},
 			wantErr: false,
 			wantMetrics: map[string]string{`inference_model_input_tokens`: `
-	# HELP inference_model_input_tokens [ALPHA] Inference model input token count distribution for requests in each model.
-	# TYPE inference_model_input_tokens histogram
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1"} 0
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="8"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="16"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="32"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="64"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="128"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="256"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="512"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1024"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="2048"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="4096"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="8192"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="16384"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="32778"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="65536"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="131072"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="262144"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="524288"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1.048576e+06"} 1
-	inference_model_input_tokens_bucket{model_name="",target_model_name="",le="+Inf"} 1
-	inference_model_input_tokens_sum{model_name="",target_model_name=""} 7
-	inference_model_input_tokens_count{model_name="",target_model_name=""} 1
-	`},
+			# HELP inference_model_input_tokens [ALPHA] Inference model input token count distribution for requests in each model.
+			# TYPE inference_model_input_tokens histogram
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1"} 0
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="8"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="16"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="32"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="64"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="128"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="256"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="512"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1024"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="2048"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="4096"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="8192"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="16384"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="32778"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="65536"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="131072"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="262144"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="524288"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="1.048576e+06"} 1
+			inference_model_input_tokens_bucket{model_name="",target_model_name="",le="+Inf"} 1
+			inference_model_input_tokens_sum{model_name="",target_model_name=""} 7
+			inference_model_input_tokens_count{model_name="",target_model_name=""} 1
+			`},
 			wantResponses: []*extProcPb.ProcessingResponse{
 				{
 					Response: &extProcPb.ProcessingResponse_ResponseHeaders{
@@ -1381,7 +1381,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) {
 							Mutation: &extProcPb.BodyMutation_StreamedResponse{
 								StreamedResponse: &extProcPb.StreamedBodyResponse{
 									Body: []byte(`data: {"id":"cmpl-0fee233f-7d56-404a-acd3-4dad775d03d9","object":"text_completion","created":1741379018,"model":"tweet-summary-1","choices":[],"usage":{"prompt_tokens":7,"total_tokens":17,"completion_tokens":10}}
-data: [DONE]`,
+			data: [DONE]`,
 									),
 									EndOfStream: false,
 								},
@@ -1409,6 +1409,63 @@ data: [DONE]`,
 			},
 		},
 	},
+		// Bodyless Request test
+		{
+			name: "simple GET Request",
+			requests: []*extProcPb.ProcessingRequest{
+				{
+					Request: &extProcPb.ProcessingRequest_RequestHeaders{
+						RequestHeaders: &extProcPb.HttpHeaders{
+							Headers: &configPb.HeaderMap{
+								Headers: []*configPb.HeaderValue{
+									{
+										Key:      "content-type",
+										RawValue: []byte("text/event-stream"),
+									},
+									{
+										Key:      "status",
+										RawValue: []byte("200"),
+									},
+								},
+							},
+							EndOfStream: true,
+						},
+					},
+				},
+			},
+			wantResponses: []*extProcPb.ProcessingResponse{
+				{
+					Response: &extProcPb.ProcessingResponse_RequestHeaders{
+						RequestHeaders: &extProcPb.HeadersResponse{
+							Response: &extProcPb.CommonResponse{
+								ClearRouteCache: true,
+								HeaderMutation: &extProcPb.HeaderMutation{
+									SetHeaders: []*configPb.HeaderValueOption{
+										{
+											Header: &configPb.HeaderValue{
+												Key:      "x-gateway-destination-endpoint",
+												RawValue: []byte("192.168.1.1:8000"),
+											},
+										},
+									}},
+							},
+						},
+					},
+					DynamicMetadata: makeMetadata("192.168.1.1:8000"),
+				},
+			},
+			pods: map[backendmetrics.Pod]*backendmetrics.Metrics{
+				fakePod(0): {
+					WaitingQueueSize:    4,
+					KVCacheUsagePercent: 0.2,
+					ActiveModels: map[string]int{
+						"foo":            1,
+						"bar":            1,
+						"sql-lora-1fdg3": 1,
+					},
+				},
+			},
+		},
 	}
 
 	for _, test := range tests {
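The new bodyless-GET case exercises `HandleRequestHeaders`, which routes through `GetRandomPod`. One caveat worth noting: `GetRandomPod` calls `rand.Intn(len(pods))` with no emptiness check, and `rand.Intn` panics when its argument is zero. A hypothetical defensive variant (not part of this patch) could signal "no pods" instead of panicking; a generic sketch of the guard:

```go
package main

import (
	"fmt"
	"math/rand"
)

// pickRandom returns a random element, or the zero value and false when
// the slice is empty, avoiding the rand.Intn(0) panic that GetRandomPod
// would hit if the datastore had no pods registered.
func pickRandom[T any](items []T) (T, bool) {
	var zero T
	if len(items) == 0 {
		return zero, false
	}
	return items[rand.Intn(len(items))], true
}

func main() {
	pods := []string{"pod-a", "pod-b", "pod-c"}
	if p, ok := pickRandom(pods); ok {
		fmt.Println("routing bodyless request to", p)
	}
	if _, ok := pickRandom([]string{}); !ok {
		fmt.Println("no pods available; caller would return an error instead of panicking")
	}
}
```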