Skip to content

Commit 100f636

Browse files
committed
Make reconciling non blocking
1 parent bea4068 commit 100f636

File tree

3 files changed

+98
-37
lines changed

3 files changed

+98
-37
lines changed

examples/dynamic-lora-sidecar/Dockerfile

+7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
FROM python:3.9-slim-buster AS test
2+
3+
WORKDIR /dynamic-lora-reconciler-test
4+
COPY requirements.txt .
5+
COPY sidecar/* .
6+
RUN pip install -r requirements.txt
7+
RUN python -m unittest discover || exit 1
18

29
FROM python:3.10-slim-buster
310

examples/dynamic-lora-sidecar/requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ aiohttp
22
jsonschema
33
pyyaml
44
requests
5-
watchfiles
5+
watchfiles
6+
watchdog

examples/dynamic-lora-sidecar/sidecar/sidecar.py

+89-36
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,51 @@
33
import time
44
from jsonschema import validate
55
from watchfiles import awatch
6-
import ipaddress
76
from dataclasses import dataclass
87
import asyncio
98
import logging
109
import datetime
1110
import os
11+
import sys
12+
from watchdog.observers.polling import PollingObserver as Observer
13+
from watchdog.events import FileSystemEventHandler
1214

13-
CONFIG_MAP_FILE = os.environ.get("DYNAMIC_LORA_ROLLOUT_CONFIG", "/config/configmap.yaml")
15+
CONFIG_MAP_FILE = os.environ.get(
16+
"DYNAMIC_LORA_ROLLOUT_CONFIG", "/config/configmap.yaml"
17+
)
1418
BASE_FIELD = "vLLMLoRAConfig"
1519
logging.basicConfig(
16-
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
17-
datefmt='%Y-%m-%d %H:%M:%S',
18-
handlers=[logging.StreamHandler()]
20+
level=logging.INFO,
21+
format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
22+
datefmt="%Y-%m-%d %H:%M:%S",
23+
handlers=[logging.StreamHandler(sys.stdout)],
1924
)
20-
logging.Formatter.converter = time.localtime
25+
logging.Formatter.converter = time.localtime
2126

2227

2328
def current_time_human() -> str:
2429
now = datetime.datetime.now(datetime.timezone.utc).astimezone()
2530
return now.strftime("%Y-%m-%d %H:%M:%S %Z%z")
2631

32+
33+
class FileChangeHandler(FileSystemEventHandler):
34+
"""Custom event handler that handles file modifications."""
35+
36+
def __init__(self, reconciler):
37+
super().__init__()
38+
self.reconciler = reconciler
39+
40+
def on_modified(self, event):
41+
logging.info("modified!")
42+
logging.info(f"Config '{CONFIG_MAP_FILE}' modified!")
43+
self.reconciler.reconcile()
44+
logging.info(f"model server reconcile to Config '{CONFIG_MAP_FILE}' !")
45+
46+
2747
@dataclass
2848
class LoraAdapter:
2949
"""Class representation of lora adapters in config"""
50+
3051
def __init__(self, id, source="", base_model=""):
3152
self.id = id
3253
self.source = source
@@ -48,34 +69,33 @@ def __init__(self, config_validation=True):
4869
self.health_check_timeout = datetime.timedelta(seconds=300)
4970
self.health_check_interval = datetime.timedelta(seconds=15)
5071
self.config_validation = config_validation
51-
52-
def validate_config(self, c)-> bool:
72+
73+
def validate_config(self, c) -> bool:
5374
try:
54-
with open('validation.yaml', 'r') as f:
75+
with open("validation.yaml", "r") as f:
5576
schema = yaml.safe_load(f)
5677
validate(instance=c, schema=schema)
5778
return True
5879
except Exception as e:
5980
logging.error(f"Cannot load config {CONFIG_MAP_FILE} validation error: {e}")
6081
return False
61-
82+
6283
@property
6384
def config(self):
6485
"""Load configmap into memory"""
6586
try:
66-
6787
with open(CONFIG_MAP_FILE, "r") as f:
6888
c = yaml.safe_load(f)
6989
if self.config_validation and not self.validate_config(c):
7090
return {}
7191
if c is None:
7292
c = {}
73-
c = c.get("vLLMLoRAConfig",{})
93+
c = c.get("vLLMLoRAConfig", {})
7494
return c
7595
except Exception as e:
7696
logging.error(f"cannot load config {CONFIG_MAP_FILE} {e}")
7797
return {}
78-
98+
7999
@property
80100
def host(self):
81101
"""Model server host"""
@@ -85,7 +105,7 @@ def host(self):
85105
def port(self):
86106
"""Model server port"""
87107
return self.config.get("port", 8000)
88-
108+
89109
@property
90110
def model_server(self):
91111
"""Model server {host}:{port}"""
@@ -95,13 +115,27 @@ def model_server(self):
95115
def ensure_exist_adapters(self):
96116
"""Lora adapters in config under key `ensureExist` in set"""
97117
adapters = self.config.get("ensureExist", {}).get("models", set())
98-
return set([LoraAdapter(adapter["id"], adapter["source"], adapter.get("base-model","")) for adapter in adapters])
118+
return set(
119+
[
120+
LoraAdapter(
121+
adapter["id"], adapter["source"], adapter.get("base-model", "")
122+
)
123+
for adapter in adapters
124+
]
125+
)
99126

100127
@property
101128
def ensure_not_exist_adapters(self):
102129
"""Lora adapters in config under key `ensureNotExist` in set"""
103130
adapters = self.config.get("ensureNotExist", {}).get("models", set())
104-
return set([LoraAdapter(adapter["id"], adapter["source"], adapter.get("base-model","")) for adapter in adapters])
131+
return set(
132+
[
133+
LoraAdapter(
134+
adapter["id"], adapter["source"], adapter.get("base-model", "")
135+
)
136+
for adapter in adapters
137+
]
138+
)
105139

106140
@property
107141
def registered_adapters(self):
@@ -123,7 +157,7 @@ def registered_adapters(self):
123157
@property
124158
def is_server_healthy(self) -> bool:
125159
"""probe server's health endpoint until timeout or success"""
126-
160+
127161
def check_health() -> bool:
128162
"""Checks server health"""
129163
url = f"http://{self.model_server}/health"
@@ -132,24 +166,26 @@ def check_health() -> bool:
132166
return response.status_code == 200
133167
except requests.exceptions.RequestException:
134168
return False
135-
169+
136170
start_time = datetime.datetime.now()
137171
while datetime.datetime.now() - start_time < self.health_check_timeout:
138172
if check_health():
139173
return True
140174
time.sleep(self.health_check_interval.seconds)
141175
return False
142-
176+
143177
def load_adapter(self, adapter: LoraAdapter):
144178
"""Sends a request to load the specified model."""
145179
if adapter in self.registered_adapters:
146-
logging.info(f"{adapter.id} already present on model server {self.model_server}")
180+
logging.info(
181+
f"{adapter.id} already present on model server {self.model_server}"
182+
)
147183
return
148184
url = f"http://{self.model_server}/v1/load_lora_adapter"
149185
payload = {
150186
"lora_name": adapter.id,
151187
"lora_path": adapter.source,
152-
"base_model_name": adapter.base_model
188+
"base_model_name": adapter.base_model,
153189
}
154190
try:
155191
response = requests.post(url, json=payload)
@@ -161,7 +197,9 @@ def load_adapter(self, adapter: LoraAdapter):
161197
def unload_adapter(self, adapter: LoraAdapter):
162198
"""Sends a request to unload the specified model."""
163199
if adapter not in self.registered_adapters:
164-
logging.info(f"{adapter.id} already doesn't exist on model server {self.model_server}")
200+
logging.info(
201+
f"{adapter.id} already doesn't exist on model server {self.model_server}"
202+
)
165203
return
166204
url = f"http://{self.model_server}/v1/unload_lora_adapter"
167205
payload = {"lora_name": adapter.id}
@@ -176,12 +214,19 @@ def unload_adapter(self, adapter: LoraAdapter):
176214

177215
def reconcile(self):
178216
"""Reconciles model server with current version of configmap"""
179-
logging.info(f"reconciling model server {self.model_server} with config stored at {CONFIG_MAP_FILE}")
217+
logging.info(
218+
f"reconciling model server {self.model_server} with config stored at {CONFIG_MAP_FILE}"
219+
)
180220
if not self.is_server_healthy:
181221
logging.error(f"vllm server at {self.model_server} not healthy")
182222
return
183-
invalid_adapters = ", ".join(str(a.id) for a in self.ensure_exist_adapters & self.ensure_not_exist_adapters)
184-
logging.warning(f"skipped adapters found in both `ensureExist` and `ensureNotExist` {invalid_adapters}")
223+
invalid_adapters = ", ".join(
224+
str(a.id)
225+
for a in self.ensure_exist_adapters & self.ensure_not_exist_adapters
226+
)
227+
logging.warning(
228+
f"skipped adapters found in both `ensureExist` and `ensureNotExist` {invalid_adapters}"
229+
)
185230
adapters_to_load = self.ensure_exist_adapters - self.ensure_not_exist_adapters
186231
adapters_to_load_id = ", ".join(str(a.id) for a in adapters_to_load)
187232
logging.info(f"adapter to load {adapters_to_load_id}")
@@ -194,18 +239,26 @@ def reconcile(self):
194239
self.unload_adapter(adapter)
195240

196241

197-
198242
async def main():
199-
"""Loads the target configuration, compares it with the server's models,
200-
and loads/unloads models accordingly."""
201-
202-
reconcilerInstance = LoraReconciler()
203-
logging.info(f"running reconcile for initial loading of configmap {CONFIG_MAP_FILE}")
204-
reconcilerInstance.reconcile()
205-
logging.info(f"beginning watching of configmap {CONFIG_MAP_FILE}")
206-
async for _ in awatch('/config/configmap.yaml'):
207-
logging.info(f"Config '{CONFIG_MAP_FILE}' modified!'" )
208-
reconcilerInstance.reconcile()
243+
reconciler_instance = LoraReconciler()
244+
logging.info(f"Running initial reconcile for config map {CONFIG_MAP_FILE}")
245+
reconciler_instance.reconcile()
246+
247+
event_handler = FileChangeHandler(reconciler_instance)
248+
observer = Observer()
249+
observer.schedule(
250+
event_handler, path=os.path.dirname(CONFIG_MAP_FILE), recursive=False
251+
)
252+
observer.start()
253+
254+
try:
255+
logging.info(f"Starting to watch {CONFIG_MAP_FILE} for changes...")
256+
while True:
257+
await asyncio.sleep(1)
258+
except KeyboardInterrupt:
259+
logging.info("Stopped by user.")
260+
observer.stop()
261+
observer.join()
209262

210263

211264
if __name__ == "__main__":

0 commit comments

Comments
 (0)