@@ -68,7 +68,8 @@ async def main():
68
68
import logging
69
69
import warnings
70
70
from collections .abc import Awaitable , Callable
71
- from typing import Any , Sequence
71
+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
72
+ from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
72
73
73
74
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
74
75
from pydantic import AnyUrl
@@ -101,13 +102,36 @@ def __init__(
101
102
self .tools_changed = tools_changed
102
103
103
104
104
- class Server :
105
+ LifespanResultT = TypeVar ("LifespanResultT" )
106
+
107
+
108
+ @asynccontextmanager
109
+ async def lifespan (server : "Server" ) -> AsyncIterator [object ]:
110
+ """Default lifespan context manager that does nothing.
111
+
112
+ Args:
113
+ server: The server instance this lifespan is managing
114
+
115
+ Returns:
116
+ An empty context object
117
+ """
118
+ yield {}
119
+
120
+
121
+ class Server (Generic [LifespanResultT ]):
105
122
def __init__ (
106
- self , name : str , version : str | None = None , instructions : str | None = None
123
+ self ,
124
+ name : str ,
125
+ version : str | None = None ,
126
+ instructions : str | None = None ,
127
+ lifespan : Callable [
128
+ ["Server" ], AbstractAsyncContextManager [LifespanResultT ]
129
+ ] = lifespan ,
107
130
):
108
131
self .name = name
109
132
self .version = version
110
133
self .instructions = instructions
134
+ self .lifespan = lifespan
111
135
self .request_handlers : dict [
112
136
type , Callable [..., Awaitable [types .ServerResult ]]
113
137
] = {
@@ -446,35 +470,43 @@ async def run(
446
470
raise_exceptions : bool = False ,
447
471
):
448
472
with warnings .catch_warnings (record = True ) as w :
449
- async with ServerSession (
450
- read_stream , write_stream , initialization_options
451
- ) as session :
452
- async for message in session .incoming_messages :
453
- logger .debug (f"Received message: { message } " )
454
-
455
- match message :
456
- case (
457
- RequestResponder (
458
- request = types .ClientRequest (root = req )
459
- ) as responder
460
- ):
461
- with responder :
462
- await self ._handle_request (
463
- message , req , session , raise_exceptions
464
- )
465
- case types .ClientNotification (root = notify ):
466
- await self ._handle_notification (notify )
467
-
468
- for warning in w :
469
- logger .info (
470
- f"Warning: { warning .category .__name__ } : { warning .message } "
471
- )
473
+ async with self .lifespan (self ) as lifespan_context :
474
+ async with ServerSession (
475
+ read_stream , write_stream , initialization_options
476
+ ) as session :
477
+ async for message in session .incoming_messages :
478
+ logger .debug (f"Received message: { message } " )
479
+
480
+ match message :
481
+ case (
482
+ RequestResponder (
483
+ request = types .ClientRequest (root = req )
484
+ ) as responder
485
+ ):
486
+ with responder :
487
+ await self ._handle_request (
488
+ message ,
489
+ req ,
490
+ session ,
491
+ lifespan_context ,
492
+ raise_exceptions ,
493
+ )
494
+ case types .ClientNotification (root = notify ):
495
+ await self ._handle_notification (notify )
496
+
497
+ for warning in w :
498
+ logger .info (
499
+ "Warning: %s: %s" ,
500
+ warning .category .__name__ ,
501
+ warning .message ,
502
+ )
472
503
473
504
async def _handle_request (
474
505
self ,
475
506
message : RequestResponder ,
476
507
req : Any ,
477
508
session : ServerSession ,
509
+ lifespan_context : object ,
478
510
raise_exceptions : bool ,
479
511
):
480
512
logger .info (f"Processing request of type { type (req ).__name__ } " )
@@ -491,6 +523,7 @@ async def _handle_request(
491
523
message .request_id ,
492
524
message .request_meta ,
493
525
session ,
526
+ lifespan_context ,
494
527
)
495
528
)
496
529
response = await handler (req )
0 commit comments