Skip to content

Commit 61bdee7

Browse files
Fix tiny-agents cli exit issues (#3125)
* first draft * Proper exit events * Update src/huggingface_hub/inference/_mcp/_cli_hacks.py Co-authored-by: célina <[email protected]> * Update src/huggingface_hub/inference/_mcp/_cli_hacks.py Co-authored-by: célina <[email protected]> * Update src/huggingface_hub/inference/_mcp/_cli_hacks.py Co-authored-by: célina <[email protected]> * make style * comment * exit_event required --------- Co-authored-by: célina <[email protected]>
1 parent a212a63 commit 61bdee7

File tree

3 files changed

+105
-14
lines changed

3 files changed

+105
-14
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import asyncio
2+
import sys
3+
from functools import partial
4+
5+
import typer
6+
7+
8+
def _patch_anyio_open_process():
9+
"""
10+
Patch anyio.open_process to allow detached processes on Windows and Unix-like systems.
11+
12+
This is necessary to prevent the MCP client from being interrupted by Ctrl+C when running in the CLI.
13+
"""
14+
import subprocess
15+
16+
import anyio
17+
18+
if getattr(anyio, "_tiny_agents_patched", False):
19+
return
20+
anyio._tiny_agents_patched = True
21+
22+
original_open_process = anyio.open_process
23+
24+
if sys.platform == "win32":
25+
# On Windows, we need to set the creation flags to create a new process group
26+
27+
async def open_process_in_new_group(*args, **kwargs):
28+
"""
29+
Wrapper for open_process to handle Windows-specific process creation flags.
30+
"""
31+
# Ensure we pass the creation flags for Windows
32+
kwargs.setdefault("creationflags", subprocess.CREATE_NEW_PROCESS_GROUP)
33+
return await original_open_process(*args, **kwargs)
34+
35+
anyio.open_process = open_process_in_new_group
36+
else:
37+
# For Unix-like systems, we can use setsid to create a new session
38+
async def open_process_in_new_group(*args, **kwargs):
39+
"""
40+
Wrapper for open_process to handle Unix-like systems with start_new_session=True.
41+
"""
42+
kwargs.setdefault("start_new_session", True)
43+
return await original_open_process(*args, **kwargs)
44+
45+
anyio.open_process = open_process_in_new_group
46+
47+
48+
async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str:
49+
"""
50+
Asynchronous prompt function that reads input from stdin without blocking.
51+
52+
This function is designed to work in an asynchronous context, allowing the event loop to gracefully stop it (e.g. on Ctrl+C).
53+
54+
Alternatively, we could use https://github.com/vxgmichel/aioconsole but that would be an additional dependency.
55+
"""
56+
loop = asyncio.get_event_loop()
57+
58+
if sys.platform == "win32":
59+
# Windows: Use run_in_executor to avoid blocking the event loop
60+
# Degraded solution: this is not ideal as user will have to CTRL+C once more to stop the prompt (and it'll not be graceful)
61+
return await loop.run_in_executor(None, partial(typer.prompt, prompt, prompt_suffix=" "))
62+
else:
63+
# UNIX-like: Use loop.add_reader for non-blocking stdin read
64+
future = loop.create_future()
65+
66+
def on_input():
67+
line = sys.stdin.readline()
68+
loop.remove_reader(sys.stdin)
69+
future.set_result(line)
70+
71+
print(prompt, end=" ", flush=True)
72+
loop.add_reader(sys.stdin, on_input) # not supported on Windows
73+
74+
# Wait for user input or exit event
75+
# Wait until either the user hits enter or exit_event is set
76+
await asyncio.wait(
77+
[future, exit_event.wait()],
78+
return_when=asyncio.FIRST_COMPLETED,
79+
)
80+
81+
# Check which one has been triggered
82+
if exit_event.is_set():
83+
future.cancel()
84+
return ""
85+
86+
line = await future
87+
return line.strip()

src/huggingface_hub/inference/_mcp/cli.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
2-
import os
32
import signal
43
import traceback
5-
from functools import partial
64
from typing import Any, Dict, List, Optional
75

86
import typer
97
from rich import print
108

9+
from ._cli_hacks import _async_prompt, _patch_anyio_open_process
1110
from .agent import Agent
1211
from .utils import _load_agent_config
1312

@@ -25,11 +24,6 @@
2524
app.add_typer(run_cli, name="run")
2625

2726

28-
async def _ainput(prompt: str = "» ") -> str:
29-
loop = asyncio.get_running_loop()
30-
return await loop.run_in_executor(None, partial(typer.prompt, prompt, prompt_suffix=" "))
31-
32-
3327
async def run_agent(
3428
agent_path: Optional[str],
3529
) -> None:
@@ -41,11 +35,14 @@ async def run_agent(
4135
Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset.
4236
4337
"""
38+
_patch_anyio_open_process() # Hacky way to prevent stdio connections to be stopped by Ctrl+C
39+
4440
config, prompt = _load_agent_config(agent_path)
4541

4642
servers: List[Dict[str, Any]] = config.get("servers", [])
4743

4844
abort_event = asyncio.Event()
45+
exit_event = asyncio.Event()
4946
first_sigint = True
5047

5148
loop = asyncio.get_running_loop()
@@ -60,8 +57,7 @@ def _sigint_handler() -> None:
6057
return
6158

6259
print("\n[red]Exiting...[/red]", flush=True)
63-
64-
os._exit(130)
60+
exit_event.set()
6561

6662
try:
6763
sigint_registered_in_loop = False
@@ -71,6 +67,7 @@ def _sigint_handler() -> None:
7167
except (AttributeError, NotImplementedError):
7268
# Windows (or any loop that doesn't support it) : fall back to sync
7369
signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
70+
7471
async with Agent(
7572
provider=config.get("provider"),
7673
model=config.get("model"),
@@ -86,8 +83,12 @@ def _sigint_handler() -> None:
8683
while True:
8784
abort_event.clear()
8885

86+
# Check if we should exit
87+
if exit_event.is_set():
88+
break
89+
8990
try:
90-
user_input = await _ainput()
91+
user_input = await _async_prompt(exit_event=exit_event)
9192
first_sigint = True
9293
except EOFError:
9394
print("\n[red]EOF received, exiting.[/red]", flush=True)
@@ -103,6 +104,8 @@ def _sigint_handler() -> None:
103104
async for chunk in agent.run(user_input, abort_event=abort_event):
104105
if abort_event.is_set() and not first_sigint:
105106
break
107+
if exit_event.is_set():
108+
break
106109

107110
if hasattr(chunk, "choices"):
108111
delta = chunk.choices[0].delta

src/huggingface_hub/inference/_mcp/mcp_client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
109109
await self.client.__aexit__(exc_type, exc_val, exc_tb)
110110
await self.cleanup()
111111

112+
async def cleanup(self):
113+
"""Clean up resources"""
114+
await self.client.close()
115+
await self.exit_stack.aclose()
116+
112117
@overload
113118
async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ...
114119

@@ -329,7 +334,3 @@ async def process_single_turn_with_tools(
329334
tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message)
330335
messages.append(tool_message_as_obj)
331336
yield tool_message_as_obj
332-
333-
async def cleanup(self):
334-
"""Clean up resources"""
335-
await self.exit_stack.aclose()

0 commit comments

Comments
 (0)