@@ -18,100 +18,50 @@ package handlers
18
18
19
19
import (
20
20
"context"
21
- "encoding/json"
22
- "fmt"
23
21
"strconv"
24
22
"time"
25
23
24
+ configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
26
25
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
27
- "sigs.k8s.io/controller-runtime/pkg/log"
28
- "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
29
- schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
26
+ "google.golang.org/protobuf/types/known/structpb"
30
27
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
31
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
32
28
)
33
29
34
- // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling.
35
- func (s * StreamingServer ) HandleRequestBody (ctx context.Context , reqCtx * RequestContext ) (* RequestContext , error ) {
36
- logger := log .FromContext (ctx )
37
-
38
- var requestBodyBytes []byte
39
- requestBodyMap := reqCtx .Request .Body
40
- // Resolve target models.
41
- model , ok := requestBodyMap ["model" ].(string )
42
- if ! ok {
43
- return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request" }
44
- }
45
- prompt , ok := requestBodyMap ["prompt" ].(string )
46
- if ! ok {
47
- return reqCtx , errutil.Error {Code : errutil .BadRequest , Msg : "prompt not found in request" }
48
- }
49
-
50
- modelName := model
30
+ func (s * StreamingServer ) HandleRequestHeaders (ctx context.Context , reqCtx * RequestContext , req * extProcPb.ProcessingRequest_RequestHeaders ) error {
31
+ reqCtx .RequestReceivedTimestamp = time .Now ()
51
32
52
- // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
53
- // This might be a security risk in the future where adapters not registered in the InferenceModel
54
- // are able to be requested by using their distinct name.
55
- modelObj := s .datastore .ModelGet (model )
56
- if modelObj == nil {
57
- return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error finding a model object in InferenceModel for input %v" , model )}
58
- }
59
- if len (modelObj .Spec .TargetModels ) > 0 {
60
- modelName = RandomWeightedDraw (logger , modelObj , 0 )
61
- if modelName == "" {
62
- return reqCtx , errutil.Error {Code : errutil .BadConfiguration , Msg : fmt .Sprintf ("error getting target model name for model %v" , modelObj .Name )}
33
+ // an EoS in the request headers means this request has no body or trailers.
34
+ if req .RequestHeaders .EndOfStream {
35
+ // We will route this request to a random pod as this is assumed to just be a GET
36
+ // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
37
+ // The above PR will address endpoint admission, but currently any request without a body will be
38
+ // routed to a random upstream pod.
39
+ pod := s .director .GetRandomPod ()
40
+ if pod == nil {
41
+ return errutil.Error {Code : errutil .Internal , Msg : "no pods available in datastore" }
63
42
}
43
+ pool , err := s .datastore .PoolGet ()
44
+ if err != nil {
45
+ return err
46
+ }
47
+ reqCtx .TargetEndpoint = pod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
48
+ reqCtx .RequestSize = 0
49
+ reqCtx .reqHeaderResp = s .generateRequestHeaderResponse (reqCtx )
50
+ return nil
64
51
}
65
- llmReq := & schedulingtypes.LLMRequest {
66
- Model : model ,
67
- ResolvedTargetModel : modelName ,
68
- Critical : modelObj .Spec .Criticality != nil && * modelObj .Spec .Criticality == v1alpha2 .Critical ,
69
- Prompt : prompt ,
70
- Headers : reqCtx .Request .Headers ,
71
- }
72
- logger .V (logutil .DEBUG ).Info ("LLM request assembled" , "request" , llmReq )
73
-
74
- var err error
75
- // Update target models in the body.
76
- if llmReq .Model != llmReq .ResolvedTargetModel {
77
- requestBodyMap ["model" ] = llmReq .ResolvedTargetModel
78
- }
79
-
80
- requestBodyBytes , err = json .Marshal (requestBodyMap )
81
- if err != nil {
82
- logger .V (logutil .DEFAULT ).Error (err , "Error marshaling request body" )
83
- return reqCtx , errutil.Error {Code : errutil .Internal , Msg : fmt .Sprintf ("error marshaling request body: %v" , err )}
84
- }
85
-
86
- res , err := s .scheduler .Schedule (ctx , llmReq )
87
- if err != nil {
88
- return reqCtx , errutil.Error {Code : errutil .InferencePoolResourceExhausted , Msg : fmt .Errorf ("failed to find target pod: %w" , err ).Error ()}
89
- }
90
- targetPod := res .TargetPod .GetPod ()
91
52
92
- // Insert target endpoint to instruct Envoy to route requests to the specified target pod.
93
- // Attach the port number
94
- pool , err := s .datastore .PoolGet ()
95
- if err != nil {
96
- return reqCtx , err
53
+ for _ , header := range req .RequestHeaders .Headers .Headers {
54
+ if header .RawValue != nil {
55
+ reqCtx .Request .Headers [header .Key ] = string (header .RawValue )
56
+ } else {
57
+ reqCtx .Request .Headers [header .Key ] = header .Value
58
+ }
97
59
}
98
- endpoint := targetPod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
99
-
100
- logger .V (logutil .DEFAULT ).Info ("Request handled" ,
101
- "model" , llmReq .Model , "targetModel" , llmReq .ResolvedTargetModel , "endpoint" , targetPod )
102
-
103
- reqCtx .Model = llmReq .Model
104
- reqCtx .ResolvedTargetModel = llmReq .ResolvedTargetModel
105
- reqCtx .RequestSize = len (requestBodyBytes )
106
- reqCtx .TargetPod = targetPod .NamespacedName .String ()
107
- reqCtx .TargetEndpoint = endpoint
108
-
109
- s .populateRequestHeaderResponse (reqCtx , endpoint , len (requestBodyBytes ))
60
+ return nil
61
+ }
110
62
111
- reqCtx .reqBodyResp = & extProcPb.ProcessingResponse {
112
- // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
113
- // and as an unstructure ext-proc response metadata key/value pair. This enables different integration
114
- // options for gateway providers.
63
+ func (s * StreamingServer ) generateRequestBodyResponse (requestBodyBytes []byte ) * extProcPb.ProcessingResponse {
64
+ return & extProcPb.ProcessingResponse {
115
65
Response : & extProcPb.ProcessingResponse_RequestBody {
116
66
RequestBody : & extProcPb.BodyResponse {
117
67
Response : & extProcPb.CommonResponse {
@@ -127,37 +77,82 @@ func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *Request
127
77
},
128
78
},
129
79
}
130
- return reqCtx , nil
131
80
}
132
81
133
- func (s * StreamingServer ) HandleRequestHeaders (ctx context.Context , reqCtx * RequestContext , req * extProcPb.ProcessingRequest_RequestHeaders ) error {
134
- reqCtx .RequestReceivedTimestamp = time .Now ()
82
+ func (s * StreamingServer ) generateRequestHeaderResponse (reqCtx * RequestContext ) * extProcPb.ProcessingResponse {
83
+ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
84
+ // and as an unstructure ext-proc response metadata key/value pair. This enables different integration
85
+ // options for gateway providers.
86
+ return & extProcPb.ProcessingResponse {
87
+ Response : & extProcPb.ProcessingResponse_RequestHeaders {
88
+ RequestHeaders : & extProcPb.HeadersResponse {
89
+ Response : & extProcPb.CommonResponse {
90
+ ClearRouteCache : true ,
91
+ HeaderMutation : & extProcPb.HeaderMutation {
92
+ SetHeaders : s .generateHeaders (reqCtx ),
93
+ },
94
+ },
95
+ },
96
+ },
97
+ DynamicMetadata : s .generateMetadata (reqCtx .TargetEndpoint ),
98
+ }
99
+ }
135
100
136
- // an EoS in the request headers means this request has no body or trailers.
137
- if req .RequestHeaders .EndOfStream {
138
- // We will route this request to a random pod as this is assumed to just be a GET
139
- // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526
140
- // The above PR will address endpoint admission, but currently any request without a body will be
141
- // routed to a random upstream pod.
142
- pod := GetRandomPod (s .datastore )
143
- if pod == nil {
144
- return errutil.Error {Code : errutil .Internal , Msg : "no pods available in datastore" }
145
- }
146
- pool , err := s .datastore .PoolGet ()
147
- if err != nil {
148
- return err
149
- }
150
- endpoint := pod .Address + ":" + strconv .Itoa (int (pool .Spec .TargetPortNumber ))
151
- s .populateRequestHeaderResponse (reqCtx , endpoint , 0 )
152
- return nil
101
+ func (s * StreamingServer ) generateHeaders (reqCtx * RequestContext ) []* configPb.HeaderValueOption {
102
+ // can likely refactor these two bespoke headers to be updated in PostDispatch, to centralize logic.
103
+ headers := []* configPb.HeaderValueOption {
104
+ {
105
+ Header : & configPb.HeaderValue {
106
+ Key : s .destinationEndpointHintKey ,
107
+ RawValue : []byte (reqCtx .TargetEndpoint ),
108
+ },
109
+ },
110
+ }
111
+ if reqCtx .RequestSize > 0 {
112
+ // We need to update the content length header if the body is mutated, see Envoy doc:
113
+ // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
114
+ headers = append (headers , & configPb.HeaderValueOption {
115
+ Header : & configPb.HeaderValue {
116
+ Key : "Content-Length" ,
117
+ RawValue : []byte (strconv .Itoa (reqCtx .RequestSize )),
118
+ },
119
+ })
153
120
}
154
121
155
- for _ , header := range req .RequestHeaders .Headers .Headers {
156
- if header .RawValue != nil {
157
- reqCtx .Request .Headers [header .Key ] = string (header .RawValue )
158
- } else {
159
- reqCtx .Request .Headers [header .Key ] = header .Value
122
+ // include all headers
123
+ for key , value := range reqCtx .Request .Headers {
124
+ headers = append (headers , & configPb.HeaderValueOption {
125
+ Header : & configPb.HeaderValue {
126
+ Key : key ,
127
+ RawValue : []byte (value ),
128
+ },
129
+ })
130
+ }
131
+ return headers
132
+ }
133
+
134
+ func (s * StreamingServer ) generateMetadata (endpoint string ) * structpb.Struct {
135
+ targetEndpointValue := & structpb.Struct {
136
+ Fields : map [string ]* structpb.Value {
137
+ s .destinationEndpointHintKey : {
138
+ Kind : & structpb.Value_StringValue {
139
+ StringValue : endpoint ,
140
+ },
141
+ },
142
+ },
143
+ }
144
+ dynamicMetadata := targetEndpointValue
145
+ if s .destinationEndpointHintMetadataNamespace != "" {
146
+ // If a namespace is defined, wrap the selected endpoint with that.
147
+ dynamicMetadata = & structpb.Struct {
148
+ Fields : map [string ]* structpb.Value {
149
+ s .destinationEndpointHintMetadataNamespace : {
150
+ Kind : & structpb.Value_StructValue {
151
+ StructValue : targetEndpointValue ,
152
+ },
153
+ },
154
+ },
160
155
}
161
156
}
162
- return nil
157
+ return dynamicMetadata
163
158
}
0 commit comments