Skip to content

Commit 66b9889

Browse files
authored
EPP architectural refactor (#781)
1 parent b4cb728 commit 66b9889

File tree

8 files changed

+696
-574
lines changed

8 files changed

+696
-574
lines changed

pkg/epp/handlers/request.go

Lines changed: 102 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -18,100 +18,50 @@ package handlers
1818

1919
import (
2020
"context"
21-
"encoding/json"
22-
"fmt"
2321
"strconv"
2422
"time"
2523

24+
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
2625
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27-
"sigs.k8s.io/controller-runtime/pkg/log"
28-
"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
29-
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26+
"google.golang.org/protobuf/types/known/structpb"
3027
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
31-
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3228
)
3329

34-
// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
35-
func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
36-
logger := log.FromContext(ctx)
37-
38-
var requestBodyBytes []byte
39-
requestBodyMap := reqCtx.Request.Body
40-
// Resolve target models.
41-
model, ok := requestBodyMap["model"].(string)
42-
if !ok {
43-
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"}
44-
}
45-
prompt, ok := requestBodyMap["prompt"].(string)
46-
if !ok {
47-
return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"}
48-
}
49-
50-
modelName := model
30+
func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
31+
reqCtx.RequestReceivedTimestamp = time.Now()
5132

52-
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
53-
// This might be a security risk in the future where adapters not registered in the InferenceModel
54-
// are able to be requested by using their distinct name.
55-
modelObj := s.datastore.ModelGet(model)
56-
if modelObj == nil {
57-
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)}
58-
}
59-
if len(modelObj.Spec.TargetModels) > 0 {
60-
modelName = RandomWeightedDraw(logger, modelObj, 0)
61-
if modelName == "" {
62-
return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)}
33+
// an EoS in the request headers means this request has no body or trailers.
34+
if req.RequestHeaders.EndOfStream {
35+
// We will route this request to a random pod as this is assumed to just be a GET
36+
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
37+
// The above PR will address endpoint admission, but currently any request without a body will be
38+
// routed to a random upstream pod.
39+
pod := s.director.GetRandomPod()
40+
if pod == nil {
41+
return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
6342
}
43+
pool, err := s.datastore.PoolGet()
44+
if err != nil {
45+
return err
46+
}
47+
reqCtx.TargetEndpoint = pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
48+
reqCtx.RequestSize = 0
49+
reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx)
50+
return nil
6451
}
65-
llmReq := &schedulingtypes.LLMRequest{
66-
Model: model,
67-
ResolvedTargetModel: modelName,
68-
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
69-
Prompt: prompt,
70-
Headers: reqCtx.Request.Headers,
71-
}
72-
logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq)
73-
74-
var err error
75-
// Update target models in the body.
76-
if llmReq.Model != llmReq.ResolvedTargetModel {
77-
requestBodyMap["model"] = llmReq.ResolvedTargetModel
78-
}
79-
80-
requestBodyBytes, err = json.Marshal(requestBodyMap)
81-
if err != nil {
82-
logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body")
83-
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)}
84-
}
85-
86-
res, err := s.scheduler.Schedule(ctx, llmReq)
87-
if err != nil {
88-
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
89-
}
90-
targetPod := res.TargetPod.GetPod()
9152

92-
// Insert target endpoint to instruct Envoy to route requests to the specified target pod.
93-
// Attach the port number
94-
pool, err := s.datastore.PoolGet()
95-
if err != nil {
96-
return reqCtx, err
53+
for _, header := range req.RequestHeaders.Headers.Headers {
54+
if header.RawValue != nil {
55+
reqCtx.Request.Headers[header.Key] = string(header.RawValue)
56+
} else {
57+
reqCtx.Request.Headers[header.Key] = header.Value
58+
}
9759
}
98-
endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
99-
100-
logger.V(logutil.DEFAULT).Info("Request handled",
101-
"model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod)
102-
103-
reqCtx.Model = llmReq.Model
104-
reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel
105-
reqCtx.RequestSize = len(requestBodyBytes)
106-
reqCtx.TargetPod = targetPod.NamespacedName.String()
107-
reqCtx.TargetEndpoint = endpoint
108-
109-
s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes))
60+
return nil
61+
}
11062

111-
reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{
112-
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
113-
// and as an unstructure ext-proc response metadata key/value pair. This enables different integration
114-
// options for gateway providers.
63+
func (s *StreamingServer) generateRequestBodyResponse(requestBodyBytes []byte) *extProcPb.ProcessingResponse {
64+
return &extProcPb.ProcessingResponse{
11565
Response: &extProcPb.ProcessingResponse_RequestBody{
11666
RequestBody: &extProcPb.BodyResponse{
11767
Response: &extProcPb.CommonResponse{
@@ -127,37 +77,82 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
12777
},
12878
},
12979
}
130-
return reqCtx, nil
13180
}
13281

133-
func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
134-
reqCtx.RequestReceivedTimestamp = time.Now()
82+
func (s *StreamingServer) generateRequestHeaderResponse(reqCtx *RequestContext) *extProcPb.ProcessingResponse {
83+
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
84+
// and as an unstructure ext-proc response metadata key/value pair. This enables different integration
85+
// options for gateway providers.
86+
return &extProcPb.ProcessingResponse{
87+
Response: &extProcPb.ProcessingResponse_RequestHeaders{
88+
RequestHeaders: &extProcPb.HeadersResponse{
89+
Response: &extProcPb.CommonResponse{
90+
ClearRouteCache: true,
91+
HeaderMutation: &extProcPb.HeaderMutation{
92+
SetHeaders: s.generateHeaders(reqCtx),
93+
},
94+
},
95+
},
96+
},
97+
DynamicMetadata: s.generateMetadata(reqCtx.TargetEndpoint),
98+
}
99+
}
135100

136-
// an EoS in the request headers means this request has no body or trailers.
137-
if req.RequestHeaders.EndOfStream {
138-
// We will route this request to a random pod as this is assumed to just be a GET
139-
// More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
140-
// The above PR will address endpoint admission, but currently any request without a body will be
141-
// routed to a random upstream pod.
142-
pod := GetRandomPod(s.datastore)
143-
if pod == nil {
144-
return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
145-
}
146-
pool, err := s.datastore.PoolGet()
147-
if err != nil {
148-
return err
149-
}
150-
endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber))
151-
s.populateRequestHeaderResponse(reqCtx, endpoint, 0)
152-
return nil
101+
func (s *StreamingServer) generateHeaders(reqCtx *RequestContext) []*configPb.HeaderValueOption {
102+
// can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
103+
headers := []*configPb.HeaderValueOption{
104+
{
105+
Header: &configPb.HeaderValue{
106+
Key: s.destinationEndpointHintKey,
107+
RawValue: []byte(reqCtx.TargetEndpoint),
108+
},
109+
},
110+
}
111+
if reqCtx.RequestSize > 0 {
112+
// We need to update the content length header if the body is mutated, see Envoy doc:
113+
// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
114+
headers = append(headers, &configPb.HeaderValueOption{
115+
Header: &configPb.HeaderValue{
116+
Key: "Content-Length",
117+
RawValue: []byte(strconv.Itoa(reqCtx.RequestSize)),
118+
},
119+
})
153120
}
154121

155-
for _, header := range req.RequestHeaders.Headers.Headers {
156-
if header.RawValue != nil {
157-
reqCtx.Request.Headers[header.Key] = string(header.RawValue)
158-
} else {
159-
reqCtx.Request.Headers[header.Key] = header.Value
122+
// include all headers
123+
for key, value := range reqCtx.Request.Headers {
124+
headers = append(headers, &configPb.HeaderValueOption{
125+
Header: &configPb.HeaderValue{
126+
Key: key,
127+
RawValue: []byte(value),
128+
},
129+
})
130+
}
131+
return headers
132+
}
133+
134+
func (s *StreamingServer) generateMetadata(endpoint string) *structpb.Struct {
135+
targetEndpointValue := &structpb.Struct{
136+
Fields: map[string]*structpb.Value{
137+
s.destinationEndpointHintKey: {
138+
Kind: &structpb.Value_StringValue{
139+
StringValue: endpoint,
140+
},
141+
},
142+
},
143+
}
144+
dynamicMetadata := targetEndpointValue
145+
if s.destinationEndpointHintMetadataNamespace != "" {
146+
// If a namespace is defined, wrap the selected endpoint with that.
147+
dynamicMetadata = &structpb.Struct{
148+
Fields: map[string]*structpb.Value{
149+
s.destinationEndpointHintMetadataNamespace: {
150+
Kind: &structpb.Value_StructValue{
151+
StructValue: targetEndpointValue,
152+
},
153+
},
154+
},
160155
}
161156
}
162-
return nil
157+
return dynamicMetadata
163158
}

0 commit comments

Comments
 (0)