Consolidating down to FULL_DUPLEX_STREAMED supported ext-proc server #672

Merged: 1 commit, Apr 9, 2025
6 changes: 0 additions & 6 deletions cmd/epp/main.go
@@ -120,11 +120,6 @@ func run() error {
flag.Parse()
initLogging(&opts)

useStreamingServer, err := strconv.ParseBool(os.Getenv("USE_STREAMING"))
if err != nil {
setupLog.Error(err, "Failed to parse env var USE_STREAMING, defaulting to false")
}

// Validate flags
if err := validateFlags(); err != nil {
setupLog.Error(err, "Failed to validate flags")
@@ -178,7 +173,6 @@ func run() error {
Datastore: datastore,
SecureServing: *secureServing,
CertPath: *certPath,
UseStreaming: useStreamingServer,
RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval,
}
if err := serverRunner.SetupWithManager(ctx, mgr); err != nil {
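
For context on the gate being deleted: strconv.ParseBool rejects an empty string, so an unset USE_STREAMING produced (false, error) and the old binary merely logged the error and fell back to the non-streaming server. A minimal, self-contained sketch of that default behavior (illustrative only; not part of this PR):

package main

import (
	"fmt"
	"os"
	"strconv"
)

func main() {
	// strconv.ParseBool("") returns an error, so an unset USE_STREAMING
	// meant useStreaming == false unless the operator set it explicitly.
	useStreaming, err := strconv.ParseBool(os.Getenv("USE_STREAMING"))
	if err != nil {
		fmt.Println("unset or invalid USE_STREAMING, defaulting to false:", err)
	}
	fmt.Println("useStreaming =", useStreaming)
}

With this PR the question is moot: the streaming server is always used, so both the env var and the UseStreaming runner field are gone.
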
3 changes: 0 additions & 3 deletions config/charts/inferencepool/templates/epp-deployment.yaml
@@ -35,9 +35,6 @@ spec:
- "9003"
- -metricsPort
- "9090"
env:
- name: USE_STREAMING
value: "true"
ports:
- name: grpc
containerPort: 9002
3 changes: 0 additions & 3 deletions config/manifests/inferencepool-resources.yaml
@@ -62,9 +62,6 @@ spec:
- "9002"
- -grpcHealthPort
- "9003"
env:
- name: USE_STREAMING
value: "true"
ports:
- containerPort: 9002
- containerPort: 9003
162 changes: 51 additions & 111 deletions pkg/epp/handlers/request.go
@@ -21,190 +21,130 @@ import (
"encoding/json"
"fmt"
"strconv"
"time"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"google.golang.org/protobuf/types/known/structpb"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// HandleRequestBody handles body of the request to the backend server, such as parsing the "model"
// parameter.
// Envoy sends the request body to ext proc before sending the request to the backend server.
func (s *Server) HandleRequestBody(
// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
func (s *StreamingServer) HandleRequestBody(
ctx context.Context,
reqCtx *RequestContext,
req *extProcPb.ProcessingRequest,
) (*extProcPb.ProcessingResponse, error) {
requestBodyMap map[string]interface{},
) (*RequestContext, error) {
var requestBodyBytes []byte
logger := log.FromContext(ctx)
loggerVerbose := logger.V(logutil.VERBOSE)
loggerVerbose.Info("Handling request body")

// Unmarshal request body (must be JSON).
v := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
var rb map[string]interface{}
if err := json.Unmarshal(v.RequestBody.Body, &rb); err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body")
return nil, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("error unmarshaling request body: %v", err)}
}
loggerVerbose.Info("Request body unmarshalled", "body", rb)

// Resolve target models.
model, ok := rb["model"].(string)
model, ok := requestBodyMap["model"].(string)
if !ok {
return nil, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
}
loggerVerbose.Info("Model requested", "model", model)

modelName := model

// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
// This might be a security risk in the future where adapters not registered in the InferenceModel
// are able to be requested by using their distinct name.
modelObj := s.datastore.ModelGet(model)
if modelObj == nil {
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
}
if len(modelObj.Spec.TargetModels) > 0 {
modelName = RandomWeightedDraw(logger, modelObj, 0)
if modelName == "" {
return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
}
}
llmReq := &schedulingtypes.LLMRequest{
Model: model,
ResolvedTargetModel: modelName,
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
}
loggerVerbose.Info("LLM request assembled", "request", llmReq)
logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical)

requestBody := v.RequestBody.Body
var err error
// Update target models in the body.
if llmReq.Model != llmReq.ResolvedTargetModel {
rb["model"] = llmReq.ResolvedTargetModel
requestBody, err = json.Marshal(rb)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
return nil, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
}
loggerVerbose.Info("Updated request body marshalled", "body", string(requestBody))
requestBodyMap["model"] = llmReq.ResolvedTargetModel
}

requestBodyBytes, err = json.Marshal(requestBodyMap)
if err != nil {
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
}

target, err := s.scheduler.Schedule(ctx, llmReq)
if err != nil {
return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}
targetPod := target.GetPod()

logger.V(logutil.DEFAULT).Info("Request handled",
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod)

// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
// Attach the port number
pool, err := s.datastore.PoolGet()
if err != nil {
return nil, err
return reqCtx, err
}
endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))

logger.V(logutil.DEFAULT).Info("Request handled",
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics",
fmt.Sprintf("%+v", target))

reqCtx.Model = llmReq.Model
reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
reqCtx.RequestSize = len(v.RequestBody.Body)
reqCtx.RequestSize = len(requestBodyBytes)
reqCtx.TargetPod = targetPod.NamespacedName.String()
reqCtx.TargetEndpoint = endpoint

headers := []*configPb.HeaderValueOption{
{
Header: &configPb.HeaderValue{
Key: s.destinationEndpointHintKey,
RawValue: []byte(endpoint),
},
},
// We need to update the content length header if the body is mutated, see Envoy doc:
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
{
Header: &configPb.HeaderValue{
Key: "Content-Length",
RawValue: []byte(strconv.Itoa(len(requestBody))),
},
},
}
// Print headers for debugging
for _, header := range headers {
logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue)
}

targetEndpointValue := &structpb.Struct{
Fields: map[string]*structpb.Value{
s.destinationEndpointHintKey: {
Kind: &structpb.Value_StringValue{
StringValue: endpoint,
},
},
},
}
dynamicMetadata := targetEndpointValue
if s.destinationEndpointHintMetadataNamespace != "" {
// If a namespace is defined, wrap the selected endpoint with that.
dynamicMetadata = &structpb.Struct{
Fields: map[string]*structpb.Value{
s.destinationEndpointHintMetadataNamespace: {
Kind: &structpb.Value_StructValue{
StructValue: targetEndpointValue,
},
},
},
}
}
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))

resp := &extProcPb.ProcessingResponse{
reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
// and as an unstructured ext-proc response metadata key/value pair. This enables different integration
// options for gateway providers.
Response: &extProcPb.ProcessingResponse_RequestBody{
RequestBody: &extProcPb.BodyResponse{
Response: &extProcPb.CommonResponse{
HeaderMutation: &extProcPb.HeaderMutation{
SetHeaders: headers,
},
BodyMutation: &extProcPb.BodyMutation{
Mutation: &extProcPb.BodyMutation_Body{
Body: requestBody,
Mutation: &extProcPb.BodyMutation_StreamedResponse{
StreamedResponse: &extProcPb.StreamedBodyResponse{
Body: requestBodyBytes,
EndOfStream: true,
},
},
},
},
},
},
DynamicMetadata: dynamicMetadata,
}
return resp, nil
return reqCtx, nil
}
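
The inline header mutation and dynamic-metadata construction that used to live here is now factored into populateRequestHeaderResponse, whose body is not part of this diff. A plausible sketch of that helper, reconstructed from the removed inline code above (the reqHeaderResp field name is an assumption, by analogy with reqCtx.reqBodyResp above; the real helper may differ):

func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
	// Instruct Envoy to route to the selected endpoint via a request header.
	headers := []*configPb.HeaderValueOption{
		{
			Header: &configPb.HeaderValue{
				Key:      s.destinationEndpointHintKey,
				RawValue: []byte(endpoint),
			},
		},
	}
	if requestBodyLength > 0 {
		// Content-Length must be updated whenever the body is mutated, per the
		// Envoy ext_proc processing-mode docs referenced in the removed code.
		headers = append(headers, &configPb.HeaderValueOption{
			Header: &configPb.HeaderValue{
				Key:      "Content-Length",
				RawValue: []byte(strconv.Itoa(requestBodyLength)),
			},
		})
	}

	// Also expose the endpoint as dynamic metadata, optionally nested under a
	// configured namespace, mirroring the struct the old code built inline.
	targetEndpointValue := &structpb.Struct{
		Fields: map[string]*structpb.Value{
			s.destinationEndpointHintKey: structpb.NewStringValue(endpoint),
		},
	}
	dynamicMetadata := targetEndpointValue
	if s.destinationEndpointHintMetadataNamespace != "" {
		dynamicMetadata = &structpb.Struct{
			Fields: map[string]*structpb.Value{
				s.destinationEndpointHintMetadataNamespace: structpb.NewStructValue(targetEndpointValue),
			},
		}
	}

	reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{
		Response: &extProcPb.ProcessingResponse_RequestHeaders{
			RequestHeaders: &extProcPb.HeadersResponse{
				Response: &extProcPb.CommonResponse{
					// Force Envoy to recompute the target cluster for the new header.
					ClearRouteCache: true,
					HeaderMutation:  &extProcPb.HeaderMutation{SetHeaders: headers},
				},
			},
		},
		DynamicMetadata: dynamicMetadata,
	}
}
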

func HandleRequestHeaders(
ctx context.Context,
reqCtx *RequestContext,
req *extProcPb.ProcessingRequest,
) *extProcPb.ProcessingResponse {
r := req.Request
h := r.(*extProcPb.ProcessingRequest_RequestHeaders)
log.FromContext(ctx).V(logutil.VERBOSE).Info("Handling request headers", "headers", h)

resp := &extProcPb.ProcessingResponse{
Response: &extProcPb.ProcessingResponse_RequestHeaders{
RequestHeaders: &extProcPb.HeadersResponse{
Response: &extProcPb.CommonResponse{
// Set `clear_route_cache = true` to force Envoy to recompute the target cluster
// based on the new "target-pod" header.
// See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse.
ClearRouteCache: true,
},
},
},
func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
reqCtx.RequestReceivedTimestamp = time.Now()

// An EoS in the request headers means this request has no body or trailers.
if req.RequestHeaders.EndOfStream {
// We will route this request to a random pod as this is assumed to just be a GET
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
// The above PR will address endpoint admission, but currently any request without a body will be
// routed to a random upstream pod.
pod := GetRandomPod(s.datastore)
pool, err := s.datastore.PoolGet()
if err != nil {
return err
}
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
}

return resp
return nil
}
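
Nothing in this file configures Envoy itself; the gateway must enable the matching processing mode for the StreamedBodyResponse mutations above to be legal. A sketch of that mode using go-control-plane types, assuming a go-control-plane release that includes FULL_DUPLEX_STREAMED (the enum names are the upstream ones; the actual wiring lives in the gateway configuration, not in this repo):

package main

import (
	"fmt"

	extprocfilterv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3"
)

func main() {
	// Headers are sent for inspection; bodies flow as full-duplex streams,
	// which is the mode this consolidated ext-proc server now assumes.
	mode := &extprocfilterv3.ProcessingMode{
		RequestHeaderMode:  extprocfilterv3.ProcessingMode_SEND,
		RequestBodyMode:    extprocfilterv3.ProcessingMode_FULL_DUPLEX_STREAMED,
		ResponseHeaderMode: extprocfilterv3.ProcessingMode_SEND,
		ResponseBodyMode:   extprocfilterv3.ProcessingMode_FULL_DUPLEX_STREAMED,
	}
	fmt.Println(mode)
}
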