diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/filter.go index cee683c5..f4848089 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/filter.go @@ -141,7 +141,7 @@ func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendm } func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool { - return pod.GetMetrics().WaitingQueueSize < queueingThresholdLoRA + return pod.GetMetrics().WaitingQueueSize < config.QueueingThresholdLoRA } // leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range @@ -223,7 +223,7 @@ func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendm // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { - if randGen.Float64() < loraAffinityThreshold { + if randGen.Float64() < config.LoraAffinityThreshold { return filtered_affinity, nil } return filtered_available, nil diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/filter_test.go index 62ffe7f2..127e6c21 100644 --- a/pkg/epp/scheduling/filter_test.go +++ b/pkg/epp/scheduling/filter_test.go @@ -442,6 +442,18 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { tolerancePercent = 5.0 // Allow 5% tolerance from expected distribution ) + // Save original config value to restore later + originalThreshold := config.LoraAffinityThreshold + + // Set a specific test value for this test + testThreshold := 0.75 // 75% + config.LoraAffinityThreshold = testThreshold + + // Ensure we restore the original threshold when test completes + defer func() { + config.LoraAffinityThreshold = originalThreshold + }() + // Create a test request and pods req := &LLMRequest{ Model: testAffinityModel, @@ -472,9 +484,10 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { affinityCount := 0 availableCount := 0 - // Use the actual loraAffinityThreshold as defined in the original code - // This test should work with whatever value is set there - expectedAffinityPercent := loraAffinityThreshold * 100 + // Use the test threshold value + expectedAffinityPercent := config.LoraAffinityThreshold * 100 + expectedAvailabilityPercent := 100 - expectedAffinityPercent + for i := 0; i < numIterations; i++ { result, err := loRASoftAffinityFilter(logger, req, toInterface(pods)) if err != nil { @@ -502,11 +515,12 @@ func TestLoRASoftAffinityDistribution(t *testing.T) { affinityLowerBound := expectedAffinityPercent - tolerancePercent affinityUpperBound := expectedAffinityPercent + tolerancePercent - availableLowerBound := actualAvailablePercent - tolerancePercent - availableUpperBound := actualAvailablePercent + tolerancePercent + availableLowerBound := expectedAvailabilityPercent - tolerancePercent + availableUpperBound := expectedAvailabilityPercent + tolerancePercent t.Logf("Distribution results over %d iterations:", numIterations) - t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, loraAffinityThreshold) + t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.LoraAffinityThreshold) + t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.LoraAffinityThreshold) t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations) t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations) diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index 63d829a1..e874724d 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -26,24 +26,46 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +// Config holds all the configuration values for the scheduler +type Config struct { + KVCacheThreshold float64 + QueueThresholdCritical int + QueueingThresholdLoRA int + LoraAffinityThreshold float64 +} + const ( - // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable. - kvCacheThreshold = 0.8 - // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable. - queueThresholdCritical = 5 - // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable. - // the threshold for queued requests to be considered low below which we can prioritize LoRA affinity. - // The value of 128 is arrived heuristicically based on experiments. - queueingThresholdLoRA = 128 - // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable. - // loraAffinityThreshold indicates the probability with which we prefer a pod with LoRA affinity over a pod without but having room to fit more LoRA adapters. - loraAffinityThreshold = 0.999 + // Default values to use if environment variables are not set + defaultKVCacheThreshold = 0.8 + defaultQueueThresholdCritical = 5 + defaultQueueingThresholdLoRA = 128 + defaultLoraAffinityThreshold = 0.999 ) +// LoadConfig loads configuration from environment variables +func LoadConfig() Config { + // Use a default logger for initial configuration loading + baseLogger := log.Log.WithName("scheduling-config") + + config := Config{ + KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), + QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), + QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), + LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), + } + + baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) + + return config +} + +var config = LoadConfig() + var ( defaultFilter = &filter{ name: "critical request", @@ -92,7 +114,7 @@ var ( // cache below a certain threshold, we consider this model server has capacity to handle // a sheddable request without impacting critical requests. name: "has capacity for sheddable requests", - filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)), + filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(config.QueueThresholdCritical, config.KVCacheThreshold)), nextOnSuccess: queueLoRAAndKVCacheFilter, // If all pods are queuing or running above the KVCache threshold, we drop the sheddable // request to make room for critical requests. @@ -123,13 +145,13 @@ type Scheduler struct { // Schedule finds the target pod based on metrics and the requested lora adapter. func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backendmetrics.PodMetrics, err error) { logger := log.FromContext(ctx).WithValues("request", req) - podMetrics := s.datastore.PodGetAll() + podMetrics := s.datastore.PodGetAll() logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", podMetrics)) + pods, err := s.filter.Filter(logger, req, podMetrics) if err != nil || len(pods) == 0 { - return nil, fmt.Errorf( - "failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) + return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err) } logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods)) i := rand.Intn(len(pods)) diff --git a/pkg/epp/util/env/env.go b/pkg/epp/util/env/env.go new file mode 100644 index 00000000..11e3bde1 --- /dev/null +++ b/pkg/epp/util/env/env.go @@ -0,0 +1,51 @@ +package env + +import ( + "os" + "strconv" + + "github.com/go-logr/logr" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// getEnvFloat gets a float64 from an environment variable with a default value +func GetEnvFloat(key string, defaultVal float64, logger logr.Logger) float64 { + val, exists := os.LookupEnv(key) + if !exists { + logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value", + "key", key, "defaultValue", defaultVal) + return defaultVal + } + + floatVal, err := strconv.ParseFloat(val, 64) + if err != nil { + logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as float, using default value", + "key", key, "value", val, "error", err, "defaultValue", defaultVal) + return defaultVal + } + + logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable", + "key", key, "value", floatVal) + return floatVal +} + +// getEnvInt gets an int from an environment variable with a default value +func GetEnvInt(key string, defaultVal int, logger logr.Logger) int { + val, exists := os.LookupEnv(key) + if !exists { + logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value", + "key", key, "defaultValue", defaultVal) + return defaultVal + } + + intVal, err := strconv.Atoi(val) + if err != nil { + logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as int, using default value", + "key", key, "value", val, "error", err, "defaultValue", defaultVal) + return defaultVal + } + + logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable", + "key", key, "value", intVal) + return intVal +} diff --git a/pkg/epp/util/env/env_test.go b/pkg/epp/util/env/env_test.go new file mode 100644 index 00000000..02513e28 --- /dev/null +++ b/pkg/epp/util/env/env_test.go @@ -0,0 +1,144 @@ +package env + +import ( + "os" + "testing" + + "github.com/go-logr/logr/testr" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +func TestGetEnvFloat(t *testing.T) { + logger := testr.New(t) + + tests := []struct { + name string + key string + value string + defaultVal float64 + expected float64 + setup func() + teardown func() + }{ + { + name: "env variable exists and is valid", + key: "TEST_FLOAT", + value: "123.456", + defaultVal: 0.0, + expected: 123.456, + setup: func() { + os.Setenv("TEST_FLOAT", "123.456") + }, + teardown: func() { + os.Unsetenv("TEST_FLOAT") + }, + }, + { + name: "env variable exists but is invalid", + key: "TEST_FLOAT", + value: "invalid", + defaultVal: 99.9, + expected: 99.9, + setup: func() { + os.Setenv("TEST_FLOAT", "invalid") + }, + teardown: func() { + os.Unsetenv("TEST_FLOAT") + }, + }, + { + name: "env variable does not exist", + key: "TEST_FLOAT_MISSING", + defaultVal: 42.42, + expected: 42.42, + setup: func() {}, + teardown: func() {}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.setup() + defer tc.teardown() + + result := GetEnvFloat(tc.key, tc.defaultVal, logger.V(logutil.VERBOSE)) + if result != tc.expected { + t.Errorf("GetEnvFloat(%s, %f) = %f, expected %f", tc.key, tc.defaultVal, result, tc.expected) + } + }) + } +} + +func TestGetEnvInt(t *testing.T) { + logger := testr.New(t) + + tests := []struct { + name string + key string + value string + defaultVal int + expected int + setup func() + teardown func() + }{ + { + name: "env variable exists and is valid", + key: "TEST_INT", + value: "123", + defaultVal: 0, + expected: 123, + setup: func() { + os.Setenv("TEST_INT", "123") + }, + teardown: func() { + os.Unsetenv("TEST_INT") + }, + }, + { + name: "env variable exists but is invalid", + key: "TEST_INT", + value: "invalid", + defaultVal: 99, + expected: 99, + setup: func() { + os.Setenv("TEST_INT", "invalid") + }, + teardown: func() { + os.Unsetenv("TEST_INT") + }, + }, + { + name: "env variable does not exist", + key: "TEST_INT_MISSING", + defaultVal: 42, + expected: 42, + setup: func() {}, + teardown: func() {}, + }, + { + name: "env variable is empty string", + key: "TEST_INT_EMPTY", + value: "", + defaultVal: 77, + expected: 77, + setup: func() { + os.Setenv("TEST_INT_EMPTY", "") + }, + teardown: func() { + os.Unsetenv("TEST_INT_EMPTY") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.setup() + defer tc.teardown() + + result := GetEnvInt(tc.key, tc.defaultVal, logger.V(logutil.VERBOSE)) + if result != tc.expected { + t.Errorf("GetEnvInt(%s, %d) = %d, expected %d", tc.key, tc.defaultVal, result, tc.expected) + } + }) + } +}