diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index 65d082c8c..5495c1da9 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -32,14 +32,11 @@ import ( ) // HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling. -func (s *StreamingServer) HandleRequestBody( - ctx context.Context, - reqCtx *RequestContext, -) (*RequestContext, error) { - var requestBodyBytes []byte +func (s *StreamingServer) HandleRequestBody(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) { logger := log.FromContext(ctx) - requestBodyMap := reqCtx.Request.Body + var requestBodyBytes []byte + requestBodyMap := reqCtx.Request.Body // Resolve target models. model, ok := requestBodyMap["model"].(string) if !ok { @@ -70,6 +67,7 @@ func (s *StreamingServer) HandleRequestBody( ResolvedTargetModel: modelName, Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, Prompt: prompt, + Headers: reqCtx.Request.Headers, } logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 0d11ebf91..c02d9f56c 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -18,7 +18,7 @@ package scheduling import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" -// SchedulerConfig creates a new SchedulerConfig object with the given plugins. +// NewSchedulerConfig creates a new SchedulerConfig object with the given plugins. 
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
type LLMRequest struct {
	// Model is the name of the model that the user specified in the request body.
	Model string
	// ResolvedTargetModel is the final target model after traffic split.
	ResolvedTargetModel string
	// Critical is a boolean that specifies if a request is critical or not.
	Critical bool
	// Prompt is the prompt that was sent in the request body.
	Prompt string
	// Headers is a map of the request headers.
	Headers map[string]string
}

// String returns a single-line, human-readable summary of the request.
//
// The summary contains the model, the resolved target model, the criticality
// flag, the length of the prompt in bytes (the prompt text itself is not
// included), and the headers map formatted with %v.
//
// A nil receiver yields "<nil>" rather than panicking, so the method is safe
// to call directly on a request pointer that was never populated.
func (r *LLMRequest) String() string {
	if r == nil {
		return "<nil>"
	}
	return fmt.Sprintf("Model: %s, ResolvedTargetModel: %s, Critical: %t, PromptLength: %d, Headers: %v",
		r.Model, r.ResolvedTargetModel, r.Critical, len(r.Prompt), r.Headers)
}