-
Notifications
You must be signed in to change notification settings - Fork 333
/
Copy pathprefix_cache.go
109 lines (93 loc) · 3.27 KB
/
prefix_cache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
/*
Copyright 2024 The Aibrix Team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package routingalgorithms
import (
"context"
"fmt"
"math/rand"
"strconv"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/prefixcacheindexer"
"github.com/vllm-project/aibrix/pkg/utils"
v1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
)
// defaultPrefixCacheMatchThresholdPercent is the match threshold, as an integer
// percentage, used when no valid override is supplied through the environment.
const defaultPrefixCacheMatchThresholdPercent = 50
// prefixCacheMatchThresholdPercent is the match threshold, as an integer
// percentage, that Route compares the matched-token ratio against. It is
// computed once, when the package-level variables are initialized.
var prefixCacheMatchThresholdPercent = getPrefixCacheMatchThresholdPercent()
// getPrefixCacheMatchThresholdPercent resolves the prefix cache match
// threshold, expressed as an integer percentage.
//
// It reads AIBRIX_PREFIX_CACHE_MATCH_THRESHOLD_PERCENT through utils.LoadEnv
// (with an empty-string default). A value that parses as an integer in the
// range 1..100 is returned unchanged. When LoadEnv yields an empty string, or
// the value is not an integer, or it lies outside 1..100, the function returns
// defaultPrefixCacheMatchThresholdPercent. Every outcome is logged.
func getPrefixCacheMatchThresholdPercent() int {
	value := utils.LoadEnv("AIBRIX_PREFIX_CACHE_MATCH_THRESHOLD_PERCENT", "")
	if value != "" {
		intValue, err := strconv.Atoi(value)
		if err == nil && intValue > 0 && intValue <= 100 {
			klog.Infof("using AIBRIX_PREFIX_CACHE_MATCH_THRESHOLD_PERCENT env value for prefix cache match threshold percent: %d", intValue)
			return intValue
		}
		// The accepted range is 1..100 (0 is rejected above), so the message
		// states exactly that; %q makes blank or oddly-quoted values visible.
		klog.Infof("invalid AIBRIX_PREFIX_CACHE_MATCH_THRESHOLD_PERCENT: %q, valid value between 1 and 100, falling back to default", value)
	}
	klog.Infof("using default prefix cache match threshold percent: %d", defaultPrefixCacheMatchThresholdPercent)
	return defaultPrefixCacheMatchThresholdPercent
}
// prefixCacheRouter is the router whose Route method chooses a pod based on
// how much of a request's token sequence is already known to the prefix cache
// indexer.
type prefixCacheRouter struct {
	// prefixCacheIndexer is queried with MatchPrefix and updated with
	// AddPrefix on each routed request.
	prefixCacheIndexer prefixcacheindexer.PrefixCacheIndexer
}
// NewPrefixCacheRouter builds a prefixCacheRouter whose indexer is a new
// prefix hash table. The returned error is always nil.
func NewPrefixCacheRouter() (Router, error) {
	router := prefixCacheRouter{prefixCacheIndexer: prefixcacheindexer.NewPrefixHashTable()}
	return router, nil
}
// Route selects a ready pod for the request and returns its address as
// produced by getPodAddress.
//
// Selection rules:
//   - No ready pods: an error is returned.
//   - Exactly one ready pod: that pod is used without tokenizing the message.
//   - Otherwise the message is tokenized and looked up in the prefix cache
//     indexer. If the share of matched tokens (integer percent) is strictly
//     greater than prefixCacheMatchThresholdPercent and at least one pod
//     matched, a random matched pod is chosen; in every other case a random
//     ready pod is chosen.
//
// Any unmatched tokens are then registered in the indexer under the chosen
// pod's name, and the decision is logged.
//
// ctx is currently unused. Errors from utils.TokenizeInputText and
// getPodAddress are returned unchanged.
func (p prefixCacheRouter) Route(ctx context.Context, pods map[string]*v1.Pod, model, message string) (string, error) {
	readyPods := utils.FilterReadyPods(pods)
	if len(readyPods) == 0 {
		return "", fmt.Errorf("no pods to forward request")
	}
	if len(readyPods) == 1 {
		// Take the pod from the filtered list. Ranging over the unfiltered
		// pods map could return a pod that is not ready.
		return getPodAddress(readyPods[0].Status.PodIP)
	}

	tokens, err := utils.TokenizeInputText(message)
	if err != nil {
		return "", err
	}

	matchedTokens, unMatchedTokens, matchedPods := p.prefixCacheIndexer.MatchPrefix(tokens, model, readyPods)

	var targetPod *v1.Pod
	// Guard both the division (no tokens, e.g. an empty message) and
	// rand.Intn (which panics for n <= 0) before taking the cache-hit path.
	if len(tokens) > 0 && len(matchedPods) > 0 &&
		len(matchedTokens)*100/len(tokens) > prefixCacheMatchThresholdPercent {
		targetPod = matchedPods[rand.Intn(len(matchedPods))]
	} else {
		// TODO: add better load balanced algorithms as fallback
		targetPod = readyPods[rand.Intn(len(readyPods))]
	}

	if len(unMatchedTokens) > 0 {
		// Presumably lets later requests with the same prefix match this pod;
		// confirm against the indexer implementation.
		p.prefixCacheIndexer.AddPrefix(unMatchedTokens, model, targetPod.Name)
	}

	// Collect pod IPs purely for the log entry below.
	matchedPodIPs := make([]string, 0, len(matchedPods))
	for _, pod := range matchedPods {
		matchedPodIPs = append(matchedPodIPs, pod.Status.PodIP)
	}
	readyPodIPs := make([]string, 0, len(readyPods))
	for _, pod := range readyPods {
		readyPodIPs = append(readyPodIPs, pod.Status.PodIP)
	}

	// NOTE(review): this logs the full request message and token list at info
	// level on every routed request; confirm that is intended.
	klog.InfoS("prefix cache route",
		"message", message,
		"tokens", tokens,
		"matched_tokens", matchedTokens,
		"unmatched_tokens", unMatchedTokens,
		"matched_pods", matchedPodIPs,
		"ready_pods", readyPodIPs,
		"target_pod", targetPod.Status.PodIP)

	return getPodAddress(targetPod.Status.PodIP)
}