3
3
import time
4
4
from jsonschema import validate
5
5
from watchfiles import awatch
6
- import ipaddress
7
6
from dataclasses import dataclass
8
7
import asyncio
9
8
import logging
10
9
import datetime
11
10
import os
11
+ import sys
12
+ from watchdog .observers .polling import PollingObserver as Observer
13
+ from watchdog .events import FileSystemEventHandler
12
14
13
- CONFIG_MAP_FILE = os .environ .get ("DYNAMIC_LORA_ROLLOUT_CONFIG" , "/config/configmap.yaml" )
15
+ CONFIG_MAP_FILE = os .environ .get (
16
+ "DYNAMIC_LORA_ROLLOUT_CONFIG" , "/config/configmap.yaml"
17
+ )
14
18
BASE_FIELD = "vLLMLoRAConfig"
15
19
logging .basicConfig (
16
- level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" ,
17
- datefmt = '%Y-%m-%d %H:%M:%S' ,
18
- handlers = [logging .StreamHandler ()]
20
+ level = logging .INFO ,
21
+ format = "%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" ,
22
+ datefmt = "%Y-%m-%d %H:%M:%S" ,
23
+ handlers = [logging .StreamHandler (sys .stdout )],
19
24
)
20
- logging .Formatter .converter = time .localtime
25
+ logging .Formatter .converter = time .localtime
21
26
22
27
23
28
def current_time_human () -> str :
24
29
now = datetime .datetime .now (datetime .timezone .utc ).astimezone ()
25
30
return now .strftime ("%Y-%m-%d %H:%M:%S %Z%z" )
26
31
32
+
33
+ class FileChangeHandler (FileSystemEventHandler ):
34
+ """Custom event handler that handles file modifications."""
35
+
36
+ def __init__ (self , reconciler ):
37
+ super ().__init__ ()
38
+ self .reconciler = reconciler
39
+
40
+ def on_modified (self , event ):
41
+ logging .info ("modified!" )
42
+ logging .info (f"Config '{ CONFIG_MAP_FILE } ' modified!" )
43
+ self .reconciler .reconcile ()
44
+ logging .info (f"model server reconcile to Config '{ CONFIG_MAP_FILE } ' !" )
45
+
46
+
27
47
@dataclass
28
48
class LoraAdapter :
29
49
"""Class representation of lora adapters in config"""
50
+
30
51
def __init__ (self , id , source = "" , base_model = "" ):
31
52
self .id = id
32
53
self .source = source
@@ -48,34 +69,33 @@ def __init__(self, config_validation=True):
48
69
self .health_check_timeout = datetime .timedelta (seconds = 300 )
49
70
self .health_check_interval = datetime .timedelta (seconds = 15 )
50
71
self .config_validation = config_validation
51
-
52
- def validate_config (self , c )-> bool :
72
+
73
+ def validate_config (self , c ) -> bool :
53
74
try :
54
- with open (' validation.yaml' , 'r' ) as f :
75
+ with open (" validation.yaml" , "r" ) as f :
55
76
schema = yaml .safe_load (f )
56
77
validate (instance = c , schema = schema )
57
78
return True
58
79
except Exception as e :
59
80
logging .error (f"Cannot load config { CONFIG_MAP_FILE } validation error: { e } " )
60
81
return False
61
-
82
+
62
83
@property
63
84
def config (self ):
64
85
"""Load configmap into memory"""
65
86
try :
66
-
67
87
with open (CONFIG_MAP_FILE , "r" ) as f :
68
88
c = yaml .safe_load (f )
69
89
if self .config_validation and not self .validate_config (c ):
70
90
return {}
71
91
if c is None :
72
92
c = {}
73
- c = c .get ("vLLMLoRAConfig" ,{})
93
+ c = c .get ("vLLMLoRAConfig" , {})
74
94
return c
75
95
except Exception as e :
76
96
logging .error (f"cannot load config { CONFIG_MAP_FILE } { e } " )
77
97
return {}
78
-
98
+
79
99
@property
80
100
def host (self ):
81
101
"""Model server host"""
@@ -85,7 +105,7 @@ def host(self):
85
105
def port (self ):
86
106
"""Model server port"""
87
107
return self .config .get ("port" , 8000 )
88
-
108
+
89
109
@property
90
110
def model_server (self ):
91
111
"""Model server {host}:{port}"""
@@ -95,13 +115,27 @@ def model_server(self):
95
115
def ensure_exist_adapters (self ):
96
116
"""Lora adapters in config under key `ensureExist` in set"""
97
117
adapters = self .config .get ("ensureExist" , {}).get ("models" , set ())
98
- return set ([LoraAdapter (adapter ["id" ], adapter ["source" ], adapter .get ("base-model" ,"" )) for adapter in adapters ])
118
+ return set (
119
+ [
120
+ LoraAdapter (
121
+ adapter ["id" ], adapter ["source" ], adapter .get ("base-model" , "" )
122
+ )
123
+ for adapter in adapters
124
+ ]
125
+ )
99
126
100
127
@property
101
128
def ensure_not_exist_adapters (self ):
102
129
"""Lora adapters in config under key `ensureNotExist` in set"""
103
130
adapters = self .config .get ("ensureNotExist" , {}).get ("models" , set ())
104
- return set ([LoraAdapter (adapter ["id" ], adapter ["source" ], adapter .get ("base-model" ,"" )) for adapter in adapters ])
131
+ return set (
132
+ [
133
+ LoraAdapter (
134
+ adapter ["id" ], adapter ["source" ], adapter .get ("base-model" , "" )
135
+ )
136
+ for adapter in adapters
137
+ ]
138
+ )
105
139
106
140
@property
107
141
def registered_adapters (self ):
@@ -123,7 +157,7 @@ def registered_adapters(self):
123
157
@property
124
158
def is_server_healthy (self ) -> bool :
125
159
"""probe server's health endpoint until timeout or success"""
126
-
160
+
127
161
def check_health () -> bool :
128
162
"""Checks server health"""
129
163
url = f"http://{ self .model_server } /health"
@@ -132,24 +166,26 @@ def check_health() -> bool:
132
166
return response .status_code == 200
133
167
except requests .exceptions .RequestException :
134
168
return False
135
-
169
+
136
170
start_time = datetime .datetime .now ()
137
171
while datetime .datetime .now () - start_time < self .health_check_timeout :
138
172
if check_health ():
139
173
return True
140
174
time .sleep (self .health_check_interval .seconds )
141
175
return False
142
-
176
+
143
177
def load_adapter (self , adapter : LoraAdapter ):
144
178
"""Sends a request to load the specified model."""
145
179
if adapter in self .registered_adapters :
146
- logging .info (f"{ adapter .id } already present on model server { self .model_server } " )
180
+ logging .info (
181
+ f"{ adapter .id } already present on model server { self .model_server } "
182
+ )
147
183
return
148
184
url = f"http://{ self .model_server } /v1/load_lora_adapter"
149
185
payload = {
150
186
"lora_name" : adapter .id ,
151
187
"lora_path" : adapter .source ,
152
- "base_model_name" : adapter .base_model
188
+ "base_model_name" : adapter .base_model ,
153
189
}
154
190
try :
155
191
response = requests .post (url , json = payload )
@@ -161,7 +197,9 @@ def load_adapter(self, adapter: LoraAdapter):
161
197
def unload_adapter (self , adapter : LoraAdapter ):
162
198
"""Sends a request to unload the specified model."""
163
199
if adapter not in self .registered_adapters :
164
- logging .info (f"{ adapter .id } already doesn't exist on model server { self .model_server } " )
200
+ logging .info (
201
+ f"{ adapter .id } already doesn't exist on model server { self .model_server } "
202
+ )
165
203
return
166
204
url = f"http://{ self .model_server } /v1/unload_lora_adapter"
167
205
payload = {"lora_name" : adapter .id }
@@ -176,12 +214,19 @@ def unload_adapter(self, adapter: LoraAdapter):
176
214
177
215
def reconcile (self ):
178
216
"""Reconciles model server with current version of configmap"""
179
- logging .info (f"reconciling model server { self .model_server } with config stored at { CONFIG_MAP_FILE } " )
217
+ logging .info (
218
+ f"reconciling model server { self .model_server } with config stored at { CONFIG_MAP_FILE } "
219
+ )
180
220
if not self .is_server_healthy :
181
221
logging .error (f"vllm server at { self .model_server } not healthy" )
182
222
return
183
- invalid_adapters = ", " .join (str (a .id ) for a in self .ensure_exist_adapters & self .ensure_not_exist_adapters )
184
- logging .warning (f"skipped adapters found in both `ensureExist` and `ensureNotExist` { invalid_adapters } " )
223
+ invalid_adapters = ", " .join (
224
+ str (a .id )
225
+ for a in self .ensure_exist_adapters & self .ensure_not_exist_adapters
226
+ )
227
+ logging .warning (
228
+ f"skipped adapters found in both `ensureExist` and `ensureNotExist` { invalid_adapters } "
229
+ )
185
230
adapters_to_load = self .ensure_exist_adapters - self .ensure_not_exist_adapters
186
231
adapters_to_load_id = ", " .join (str (a .id ) for a in adapters_to_load )
187
232
logging .info (f"adapter to load { adapters_to_load_id } " )
@@ -194,18 +239,26 @@ def reconcile(self):
194
239
self .unload_adapter (adapter )
195
240
196
241
197
-
198
242
async def main ():
199
- """Loads the target configuration, compares it with the server's models,
200
- and loads/unloads models accordingly."""
201
-
202
- reconcilerInstance = LoraReconciler ()
203
- logging .info (f"running reconcile for initial loading of configmap { CONFIG_MAP_FILE } " )
204
- reconcilerInstance .reconcile ()
205
- logging .info (f"beginning watching of configmap { CONFIG_MAP_FILE } " )
206
- async for _ in awatch ('/config/configmap.yaml' ):
207
- logging .info (f"Config '{ CONFIG_MAP_FILE } ' modified!'" )
208
- reconcilerInstance .reconcile ()
243
+ reconciler_instance = LoraReconciler ()
244
+ logging .info (f"Running initial reconcile for config map { CONFIG_MAP_FILE } " )
245
+ reconciler_instance .reconcile ()
246
+
247
+ event_handler = FileChangeHandler (reconciler_instance )
248
+ observer = Observer ()
249
+ observer .schedule (
250
+ event_handler , path = os .path .dirname (CONFIG_MAP_FILE ), recursive = False
251
+ )
252
+ observer .start ()
253
+
254
+ try :
255
+ logging .info (f"Starting to watch { CONFIG_MAP_FILE } for changes..." )
256
+ while True :
257
+ await asyncio .sleep (1 )
258
+ except KeyboardInterrupt :
259
+ logging .info ("Stopped by user." )
260
+ observer .stop ()
261
+ observer .join ()
209
262
210
263
211
264
if __name__ == "__main__" :
0 commit comments