Skip to content

Commit 47e32bf

Browse files
authored
Merge pull request #137 from vMaroon/chat-completions
Add Support for OpenAI ChatCompletions API - PrefixAware Scoring
2 parents 5ee8225 + 4ddfc9f commit 47e32bf

File tree

7 files changed

+146
-9
lines changed

7 files changed

+146
-9
lines changed

Makefile

-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,6 @@ image-build: check-container-tool load-version-json ## Build container image usi
512512
--build-arg TARGETARCH=$(TARGETARCH) \
513513
--build-arg GIT_NM_USER=$(GIT_NM_USER)\
514514
--build-arg NM_TOKEN=$(NM_TOKEN) \
515-
--progress=plain \
516515
-t $(IMG) .
517516

518517
.PHONY: image-push

go.mod

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,19 @@ go 1.24.1
55
toolchain go1.24.2
66

77
require (
8+
github.com/cespare/xxhash/v2 v2.3.0
89
github.com/elastic/crd-ref-docs v0.1.0
910
github.com/envoyproxy/go-control-plane/envoy v1.32.4
1011
github.com/go-logr/logr v1.4.2
1112
github.com/google/go-cmp v0.7.0
13+
github.com/hashicorp/golang-lru/v2 v2.0.7
1214
github.com/neuralmagic/llm-d-kv-cache-manager v0.0.0-20250430102735-86595011431d
1315
github.com/onsi/ginkgo/v2 v2.23.4
1416
github.com/onsi/gomega v1.37.0
1517
github.com/prometheus/client_golang v1.22.0
1618
github.com/prometheus/client_model v0.6.2
1719
github.com/prometheus/common v0.63.0
20+
github.com/sashabaranov/go-openai v1.39.1
1821
github.com/stretchr/testify v1.10.0
1922
go.uber.org/multierr v1.11.0
2023
go.uber.org/zap v1.27.0
@@ -42,7 +45,6 @@ require (
4245
github.com/beorn7/perks v1.0.1 // indirect
4346
github.com/blang/semver/v4 v4.0.0 // indirect
4447
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
45-
github.com/cespare/xxhash/v2 v2.3.0 // indirect
4648
github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 // indirect
4749
github.com/daulet/tokenizers v1.20.2 // indirect
4850
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
@@ -74,7 +76,6 @@ require (
7476
github.com/google/uuid v1.6.0 // indirect
7577
github.com/gorilla/websocket v1.5.0 // indirect
7678
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
77-
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
7879
github.com/huandu/xstrings v1.3.3 // indirect
7980
github.com/imdario/mergo v0.3.11 // indirect
8081
github.com/inconshreveable/mousetrap v1.1.0 // indirect

go.sum

+2
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRl
189189
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
190190
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
191191
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
192+
github.com/sashabaranov/go-openai v1.39.1 h1:TMD4w77Iy9WTFlgnjNaxbAASdsCJ9R/rMdzL+SN14oU=
193+
github.com/sashabaranov/go-openai v1.39.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
192194
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
193195
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
194196
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

pkg/epp/handlers/request.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ func (s *StreamingServer) HandleRequestBody(
7979
if llmReq.Model != llmReq.ResolvedTargetModel {
8080
requestBodyMap["model"] = llmReq.ResolvedTargetModel
8181
}
82-
// Extract prompt from the request body.
82+
// Extract prompt/messages from the request body.
8383
if prompt, ok := requestBodyMap["prompt"].(string); ok {
8484
llmReq.Prompt = prompt
85+
} else if _, ok := requestBodyMap["messages"]; ok { // check for chat completion request
86+
if chatRequest, err := schedulingtypes.NewKVCacheChatCompletionRequest(requestBodyMap); err == nil {
87+
llmReq.ChatCompletionRequest = chatRequest
88+
} else {
89+
logger.Error(err, "Error creating chat completion request")
90+
}
8591
}
8692

8793
requestBodyBytes, err = json.Marshal(requestBodyMap)

pkg/epp/scheduling/plugins/scorer/prefix_aware_scorer.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@ func (s *PrefixAwareScorer) Score(ctx *types.SchedulingContext, pods []types.Pod
5454
return nil
5555
}
5656

57-
scores := s.prefixStore.FindMatchingPods(ctx.Req.Prompt, ctx.Req.Model)
57+
var prompt string
58+
if ctx.Req.ChatCompletionRequest != nil {
59+
prompt = ctx.Req.ChatCompletionRequest.ToString()
60+
} else {
61+
prompt = ctx.Req.Prompt
62+
}
63+
64+
scores := s.prefixStore.FindMatchingPods(prompt, ctx.Req.Model)
5865
loggerDebug.Info("Got pod scores", "scores", scores)
5966

6067
if len(scores) == 0 {
@@ -92,7 +99,14 @@ func (s *PrefixAwareScorer) PostSchedule(ctx *types.SchedulingContext, res *type
9299
return
93100
}
94101

95-
if err := s.prefixStore.AddEntry(ctx.Req.Model, ctx.Req.Prompt, &pod.GetPod().NamespacedName); err != nil {
102+
var prompt string
103+
if ctx.Req.ChatCompletionRequest != nil {
104+
prompt = ctx.Req.ChatCompletionRequest.ToString()
105+
} else {
106+
prompt = ctx.Req.Prompt
107+
}
108+
109+
if err := s.prefixStore.AddEntry(ctx.Req.Model, prompt, &pod.GetPod().NamespacedName); err != nil {
96110
debugLogger.Error(err, "Failed to add entry to prefix store", "req", ctx.Req, "pod", pod)
97111
return
98112
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package types
18+
19+
import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/sashabaranov/go-openai"
)
25+
26+
// KVCacheChatCompletionRequest is a struct that represents the fields from an
27+
// OpenAI API ChatCompletionRequest that are relevant for KV cache generation.
28+
// Model is not included as it is contained in the LLMRequest struct.
29+
//
30+
// Multimodal requests are not supported in the current implementation.
31+
type KVCacheChatCompletionRequest struct {
32+
Messages []openai.ChatCompletionMessage `json:"messages"`
33+
Tools []openai.Tool `json:"tools,omitempty"`
34+
ToolChoices []openai.ToolChoice `json:"tool_choices,omitempty"`
35+
}
36+
37+
// NewKVCacheChatCompletionRequest creates a new KVCacheChatCompletionRequest
38+
// from a json request.
39+
//
40+
// The call marshals the input map to JSON and then unmarshals it into the
41+
// KVCacheChatCompletionRequest struct.
42+
func NewKVCacheChatCompletionRequest(input map[string]interface{}) (*KVCacheChatCompletionRequest, error) {
43+
var req KVCacheChatCompletionRequest
44+
45+
if messagesRaw, ok := input["messages"]; ok {
46+
bytes, err := json.Marshal(messagesRaw)
47+
if err != nil {
48+
return nil, err
49+
}
50+
if err := json.Unmarshal(bytes, &req.Messages); err != nil {
51+
return nil, err
52+
}
53+
}
54+
55+
if toolsRaw, ok := input["tools"]; ok {
56+
bytes, err := json.Marshal(toolsRaw)
57+
if err != nil {
58+
return nil, err
59+
}
60+
if err := json.Unmarshal(bytes, &req.Tools); err != nil {
61+
return nil, err
62+
}
63+
}
64+
65+
if choicesRaw, ok := input["tool_choices"]; ok {
66+
bytes, err := json.Marshal(choicesRaw)
67+
if err != nil {
68+
return nil, err
69+
}
70+
if err := json.Unmarshal(bytes, &req.ToolChoices); err != nil {
71+
return nil, err
72+
}
73+
}
74+
75+
return &req, nil
76+
}
77+
78+
// ToString generates a string representation of the KVCacheChatCompletionRequest.
79+
func (r *KVCacheChatCompletionRequest) ToString() string {
80+
var builder strings.Builder
81+
82+
for _, msg := range r.Messages {
83+
builder.WriteString(msg.Role)
84+
builder.WriteString(":")
85+
builder.WriteString(msg.Content)
86+
builder.WriteString("\n")
87+
}
88+
89+
if len(r.Tools) > 0 {
90+
toolsJSON, err := json.Marshal(r.Tools)
91+
if err == nil {
92+
builder.WriteString("tools:")
93+
builder.Write(toolsJSON)
94+
builder.WriteString("\n")
95+
}
96+
}
97+
98+
if len(r.ToolChoices) > 0 {
99+
choicesJSON, err := json.Marshal(r.ToolChoices)
100+
if err == nil {
101+
builder.WriteString("tool_choices:")
102+
builder.Write(choicesJSON)
103+
builder.WriteString("\n")
104+
}
105+
}
106+
107+
return builder.String()
108+
}

pkg/epp/scheduling/types/types.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ import (
2727

2828
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
2929
type LLMRequest struct {
30-
Model string
31-
Prompt string
30+
Model string
31+
Prompt string
32+
ChatCompletionRequest *KVCacheChatCompletionRequest
3233
// Target models is a map of target model name to weight.
3334
TargetModels map[string]int
3435
Headers map[string]string
@@ -39,7 +40,13 @@ type LLMRequest struct {
3940
}
4041

4142
func (r *LLMRequest) String() string {
42-
return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, len(r.Prompt))
43+
promptLength := len(r.Prompt)
44+
if r.ChatCompletionRequest != nil {
45+
promptLength = len(r.ChatCompletionRequest.ToString())
46+
}
47+
48+
return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v",
49+
r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, promptLength)
4350
}
4451

4552
type Pod interface {

0 commit comments

Comments
 (0)