Skip to content

Commit 16ded66

Browse files
authored
update algorithm parameters from env variables (#580)
* update algorithm parameters from env variables * move env parsers to a new pkg in utils * add unit test for env parser * remove logging env variables during scheduling * add test for env parser
1 parent 4182265 commit 16ded66

File tree

5 files changed

+254
-23
lines changed

5 files changed

+254
-23
lines changed

pkg/epp/scheduling/filter.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendm
141141
}
142142

143143
func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool {
144-
return pod.GetMetrics().WaitingQueueSize < queueingThresholdLoRA
144+
return pod.GetMetrics().WaitingQueueSize < config.QueueingThresholdLoRA
145145
}
146146

147147
// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
@@ -223,7 +223,7 @@ func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendm
223223

224224
// If both groups have pods, use probability to select which group to return
225225
if len(filtered_affinity) > 0 && len(filtered_available) > 0 {
226-
if randGen.Float64() < loraAffinityThreshold {
226+
if randGen.Float64() < config.LoraAffinityThreshold {
227227
return filtered_affinity, nil
228228
}
229229
return filtered_available, nil

pkg/epp/scheduling/filter_test.go

+20-6
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,18 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
442442
tolerancePercent = 5.0 // Allow 5% tolerance from expected distribution
443443
)
444444

445+
// Save original config value to restore later
446+
originalThreshold := config.LoraAffinityThreshold
447+
448+
// Set a specific test value for this test
449+
testThreshold := 0.75 // 75%
450+
config.LoraAffinityThreshold = testThreshold
451+
452+
// Ensure we restore the original threshold when test completes
453+
defer func() {
454+
config.LoraAffinityThreshold = originalThreshold
455+
}()
456+
445457
// Create a test request and pods
446458
req := &LLMRequest{
447459
Model: testAffinityModel,
@@ -472,9 +484,10 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
472484
affinityCount := 0
473485
availableCount := 0
474486

475-
// Use the actual loraAffinityThreshold as defined in the original code
476-
// This test should work with whatever value is set there
477-
expectedAffinityPercent := loraAffinityThreshold * 100
487+
// Use the test threshold value
488+
expectedAffinityPercent := config.LoraAffinityThreshold * 100
489+
expectedAvailabilityPercent := 100 - expectedAffinityPercent
490+
478491
for i := 0; i < numIterations; i++ {
479492
result, err := loRASoftAffinityFilter(logger, req, toInterface(pods))
480493
if err != nil {
@@ -502,11 +515,12 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
502515
affinityLowerBound := expectedAffinityPercent - tolerancePercent
503516
affinityUpperBound := expectedAffinityPercent + tolerancePercent
504517

505-
availableLowerBound := actualAvailablePercent - tolerancePercent
506-
availableUpperBound := actualAvailablePercent + tolerancePercent
518+
availableLowerBound := expectedAvailabilityPercent - tolerancePercent
519+
availableUpperBound := expectedAvailabilityPercent + tolerancePercent
507520

508521
t.Logf("Distribution results over %d iterations:", numIterations)
509-
t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, loraAffinityThreshold)
522+
t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.LoraAffinityThreshold)
523+
t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.LoraAffinityThreshold)
510524
t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations)
511525
t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations)
512526

pkg/epp/scheduling/scheduler.go

+37-15
Original file line numberDiff line numberDiff line change
@@ -26,24 +26,46 @@ import (
2626
"sigs.k8s.io/controller-runtime/pkg/log"
2727
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
2828
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
29+
envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env"
2930
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
3031
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3132
)
3233

34+
// Config holds all the configuration values for the scheduler.
// LoadConfig populates each field from an environment variable and falls
// back to the matching default* constant.
35+
type Config struct {
36+
// KVCacheThreshold is read from KV_CACHE_THRESHOLD. It is passed to
// noQueueAndLessThanKVCacheThresholdPredicate for the
// "has capacity for sheddable requests" filter.
KVCacheThreshold float64
37+
// QueueThresholdCritical is read from QUEUE_THRESHOLD_CRITICAL. It is passed
// to the same predicate as KVCacheThreshold.
QueueThresholdCritical int
38+
// QueueingThresholdLoRA is read from QUEUING_THRESHOLD_LORA (note the
// different spelling of the env name). lowQueueingPodPredicate accepts a pod
// when its WaitingQueueSize is strictly below this value.
QueueingThresholdLoRA int
39+
// LoraAffinityThreshold is read from LORA_AFFINITY_THRESHOLD. When both the
// affinity and the available groups have pods, loRASoftAffinityFilter returns
// the affinity group if a random draw from randGen.Float64() is below this value.
LoraAffinityThreshold float64
40+
}
41+
3342
const (
34-
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
35-
kvCacheThreshold = 0.8
36-
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
37-
queueThresholdCritical = 5
38-
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
39-
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
40-
// The value of 128 was arrived at heuristically, based on experiments.
41-
queueingThresholdLoRA = 128
42-
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
43-
// loraAffinityThreshold indicates the probability with which we prefer a pod with LoRA affinity over a pod without but having room to fit more LoRA adapters.
44-
loraAffinityThreshold = 0.999
43+
// Default values to use if environment variables are not set
44+
defaultKVCacheThreshold = 0.8
45+
defaultQueueThresholdCritical = 5
46+
defaultQueueingThresholdLoRA = 128
47+
defaultLoraAffinityThreshold = 0.999
4548
)
4649

50+
// LoadConfig loads configuration from environment variables
51+
func LoadConfig() Config {
52+
// Use a default logger for initial configuration loading
53+
baseLogger := log.Log.WithName("scheduling-config")
54+
55+
config := Config{
56+
KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger),
57+
QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger),
58+
QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger),
59+
LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger),
60+
}
61+
62+
baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config)
63+
64+
return config
65+
}
66+
67+
// config is the package-level scheduler configuration, initialized once by
// LoadConfig when the package's variables are initialized.
var config = LoadConfig()
68+
4769
var (
4870
defaultFilter = &filter{
4971
name: "critical request",
@@ -92,7 +114,7 @@ var (
92114
// cache below a certain threshold, we consider this model server has capacity to handle
93115
// a sheddable request without impacting critical requests.
94116
name: "has capacity for sheddable requests",
95-
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)),
117+
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(config.QueueThresholdCritical, config.KVCacheThreshold)),
96118
nextOnSuccess: queueLoRAAndKVCacheFilter,
97119
// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
98120
// request to make room for critical requests.
@@ -123,13 +145,13 @@ type Scheduler struct {
123145
// Schedule finds the target pod based on metrics and the requested lora adapter.
124146
func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backendmetrics.PodMetrics, err error) {
125147
logger := log.FromContext(ctx).WithValues("request", req)
126-
podMetrics := s.datastore.PodGetAll()
127148

149+
podMetrics := s.datastore.PodGetAll()
128150
logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", podMetrics))
151+
129152
pods, err := s.filter.Filter(logger, req, podMetrics)
130153
if err != nil || len(pods) == 0 {
131-
return nil, fmt.Errorf(
132-
"failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err)
154+
return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err)
133155
}
134156
logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
135157
i := rand.Intn(len(pods))

pkg/epp/util/env/env.go

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package env
2+
3+
import (
4+
"os"
5+
"strconv"
6+
7+
"github.com/go-logr/logr"
8+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
9+
)
10+
11+
// getEnvFloat gets a float64 from an environment variable with a default value
12+
func GetEnvFloat(key string, defaultVal float64, logger logr.Logger) float64 {
13+
val, exists := os.LookupEnv(key)
14+
if !exists {
15+
logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value",
16+
"key", key, "defaultValue", defaultVal)
17+
return defaultVal
18+
}
19+
20+
floatVal, err := strconv.ParseFloat(val, 64)
21+
if err != nil {
22+
logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as float, using default value",
23+
"key", key, "value", val, "error", err, "defaultValue", defaultVal)
24+
return defaultVal
25+
}
26+
27+
logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable",
28+
"key", key, "value", floatVal)
29+
return floatVal
30+
}
31+
32+
// getEnvInt gets an int from an environment variable with a default value
33+
func GetEnvInt(key string, defaultVal int, logger logr.Logger) int {
34+
val, exists := os.LookupEnv(key)
35+
if !exists {
36+
logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value",
37+
"key", key, "defaultValue", defaultVal)
38+
return defaultVal
39+
}
40+
41+
intVal, err := strconv.Atoi(val)
42+
if err != nil {
43+
logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as int, using default value",
44+
"key", key, "value", val, "error", err, "defaultValue", defaultVal)
45+
return defaultVal
46+
}
47+
48+
logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable",
49+
"key", key, "value", intVal)
50+
return intVal
51+
}

pkg/epp/util/env/env_test.go

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package env
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/go-logr/logr/testr"
8+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
9+
)
10+
11+
func TestGetEnvFloat(t *testing.T) {
12+
logger := testr.New(t)
13+
14+
tests := []struct {
15+
name string
16+
key string
17+
value string
18+
defaultVal float64
19+
expected float64
20+
setup func()
21+
teardown func()
22+
}{
23+
{
24+
name: "env variable exists and is valid",
25+
key: "TEST_FLOAT",
26+
value: "123.456",
27+
defaultVal: 0.0,
28+
expected: 123.456,
29+
setup: func() {
30+
os.Setenv("TEST_FLOAT", "123.456")
31+
},
32+
teardown: func() {
33+
os.Unsetenv("TEST_FLOAT")
34+
},
35+
},
36+
{
37+
name: "env variable exists but is invalid",
38+
key: "TEST_FLOAT",
39+
value: "invalid",
40+
defaultVal: 99.9,
41+
expected: 99.9,
42+
setup: func() {
43+
os.Setenv("TEST_FLOAT", "invalid")
44+
},
45+
teardown: func() {
46+
os.Unsetenv("TEST_FLOAT")
47+
},
48+
},
49+
{
50+
name: "env variable does not exist",
51+
key: "TEST_FLOAT_MISSING",
52+
defaultVal: 42.42,
53+
expected: 42.42,
54+
setup: func() {},
55+
teardown: func() {},
56+
},
57+
}
58+
59+
for _, tc := range tests {
60+
t.Run(tc.name, func(t *testing.T) {
61+
tc.setup()
62+
defer tc.teardown()
63+
64+
result := GetEnvFloat(tc.key, tc.defaultVal, logger.V(logutil.VERBOSE))
65+
if result != tc.expected {
66+
t.Errorf("GetEnvFloat(%s, %f) = %f, expected %f", tc.key, tc.defaultVal, result, tc.expected)
67+
}
68+
})
69+
}
70+
}
71+
72+
func TestGetEnvInt(t *testing.T) {
73+
logger := testr.New(t)
74+
75+
tests := []struct {
76+
name string
77+
key string
78+
value string
79+
defaultVal int
80+
expected int
81+
setup func()
82+
teardown func()
83+
}{
84+
{
85+
name: "env variable exists and is valid",
86+
key: "TEST_INT",
87+
value: "123",
88+
defaultVal: 0,
89+
expected: 123,
90+
setup: func() {
91+
os.Setenv("TEST_INT", "123")
92+
},
93+
teardown: func() {
94+
os.Unsetenv("TEST_INT")
95+
},
96+
},
97+
{
98+
name: "env variable exists but is invalid",
99+
key: "TEST_INT",
100+
value: "invalid",
101+
defaultVal: 99,
102+
expected: 99,
103+
setup: func() {
104+
os.Setenv("TEST_INT", "invalid")
105+
},
106+
teardown: func() {
107+
os.Unsetenv("TEST_INT")
108+
},
109+
},
110+
{
111+
name: "env variable does not exist",
112+
key: "TEST_INT_MISSING",
113+
defaultVal: 42,
114+
expected: 42,
115+
setup: func() {},
116+
teardown: func() {},
117+
},
118+
{
119+
name: "env variable is empty string",
120+
key: "TEST_INT_EMPTY",
121+
value: "",
122+
defaultVal: 77,
123+
expected: 77,
124+
setup: func() {
125+
os.Setenv("TEST_INT_EMPTY", "")
126+
},
127+
teardown: func() {
128+
os.Unsetenv("TEST_INT_EMPTY")
129+
},
130+
},
131+
}
132+
133+
for _, tc := range tests {
134+
t.Run(tc.name, func(t *testing.T) {
135+
tc.setup()
136+
defer tc.teardown()
137+
138+
result := GetEnvInt(tc.key, tc.defaultVal, logger.V(logutil.VERBOSE))
139+
if result != tc.expected {
140+
t.Errorf("GetEnvInt(%s, %d) = %d, expected %d", tc.key, tc.defaultVal, result, tc.expected)
141+
}
142+
})
143+
}
144+
}

0 commit comments

Comments
 (0)