-
Notifications
You must be signed in to change notification settings - Fork 6
[WIP] feat: add watchdog for runtime session handling #187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| { | ||
| "version": "2.0", | ||
| "resources": [] | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| { | ||
| "servers": { | ||
| "math-server": { | ||
| "coded-math-mcp": { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this related to the watchdog PR? |
||
| "transport": "stdio", | ||
| "command": "python", | ||
| "args": ["server.py"] | ||
|
|
||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -37,8 +37,14 @@ | |||||
| from .._utils._config import McpServer | ||||||
| from ._context import UiPathServerType | ||||||
| from ._exception import McpErrorCode, UiPathMcpRuntimeError | ||||||
| from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer | ||||||
| from ._session import ( | ||||||
| BaseSessionServer, | ||||||
| SessionHealthInfo, | ||||||
| StdioSessionServer, | ||||||
| StreamableHttpSessionServer, | ||||||
| ) | ||||||
| from ._token_refresh import TokenRefresher | ||||||
| from ._watchdog import SessionWatchdog | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
| tracer = trace.get_tracer(__name__) | ||||||
|
|
@@ -86,6 +92,7 @@ def __init__( | |||||
| self._http_stderr_drain_task: asyncio.Task[None] | None = None | ||||||
| self._http_server_stderr_lines: list[str] = [] | ||||||
| self._uipath = UiPath() | ||||||
| self._watchdog: SessionWatchdog | None = None | ||||||
| self._token_refresher: TokenRefresher | None = None | ||||||
| self._cleanup_done = False | ||||||
|
|
||||||
|
|
@@ -118,6 +125,38 @@ def _validate_auth(self) -> None: | |||||
| UiPathErrorCategory.SYSTEM, | ||||||
| ) | ||||||
|
|
||||||
| def get_sessions(self) -> dict[str, SessionHealthInfo]: | ||||||
| """Return health info for all active sessions (SessionProvider protocol).""" | ||||||
| return { | ||||||
| sid: session.get_health_info() | ||||||
| for sid, session in self._session_servers.items() | ||||||
| } | ||||||
|
|
||||||
| async def remove_session(self, session_id: str, reason: str) -> None: | ||||||
| """Pop, stop, and clean up a single session (SessionProvider protocol).""" | ||||||
| session_server = self._session_servers.pop(session_id, None) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like a pre-existing race widened by this PR. pop() removes the entry before `await session_server.stop()`, and stop() can block for up to 3s. During that time, a new messageReceived for the same session id will pass the |
||||||
| if session_server is None: | ||||||
| return | ||||||
|
|
||||||
| logger.warning(f"Removing session {session_id}: {reason}") | ||||||
|
|
||||||
| try: | ||||||
| await session_server.stop() | ||||||
| except Exception: | ||||||
| logger.error( | ||||||
| f"Error stopping session {session_id}", | ||||||
| exc_info=True, | ||||||
| ) | ||||||
|
|
||||||
| if session_server.output: | ||||||
| if self.sandboxed: | ||||||
| self._session_output = session_server.output | ||||||
| else: | ||||||
| logger.info(f"Session {session_id} output: {session_server.output}") | ||||||
|
|
||||||
| if self.sandboxed: | ||||||
| self._cancel_event.set() | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This unifies four previously distinct paths into one, which is a real improvement, but the trailing `if self.sandboxed: self._cancel_event.set()` now runs on all of them. For HTTP-crash and SignalR-closed this is correct. For watchdog dead-task: in a sandboxed runtime that hosts multiple session servers, one task dying with an exception will now kill the whole runtime. Is this intended? Maybe we can make it explicit, idk: |
||||||
|
|
||||||
| async def get_schema(self) -> UiPathRuntimeSchema: | ||||||
| """Get schema for this MCP runtime. | ||||||
|
|
||||||
|
|
@@ -240,6 +279,9 @@ async def _run_server(self) -> UiPathRuntimeResult: | |||||
| run_task = asyncio.create_task(self._signalr_client.run()) | ||||||
| cancel_task = asyncio.create_task(self._cancel_event.wait()) | ||||||
| self._keep_alive_task = asyncio.create_task(self._keep_alive()) | ||||||
|
|
||||||
| self._watchdog = SessionWatchdog(self) | ||||||
| self._watchdog.start() | ||||||
| self._token_refresher.start() | ||||||
|
|
||||||
| try: | ||||||
|
|
@@ -253,8 +295,8 @@ async def _run_server(self) -> UiPathRuntimeResult: | |||||
| ) | ||||||
| self._cancel_event.set() | ||||||
| finally: | ||||||
| # Cancel any pending tasks gracefully | ||||||
| for task in [run_task, cancel_task, self._keep_alive_task]: | ||||||
| # Cancel pending tasks | ||||||
| for task in [run_task, cancel_task]: | ||||||
| if task and not task.done(): | ||||||
| task.cancel() | ||||||
| try: | ||||||
|
|
@@ -280,7 +322,7 @@ async def _run_server(self) -> UiPathRuntimeResult: | |||||
| except Exception as e: | ||||||
| if isinstance(e, UiPathMcpRuntimeError): | ||||||
| raise | ||||||
| detail = f"Error: {str(e)}" | ||||||
| detail = f"Error: {e}" | ||||||
| raise UiPathMcpRuntimeError( | ||||||
| UiPathErrorCode.EXECUTION_ERROR, | ||||||
| "MCP Runtime execution failed", | ||||||
|
|
@@ -312,11 +354,12 @@ async def _cleanup(self) -> None: | |||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
| for session_id, session_server in list(self._session_servers.items()): | ||||||
| try: | ||||||
| await session_server.stop() | ||||||
| except Exception as e: | ||||||
| logger.error(f"Error cleaning up session server {session_id}: {str(e)}") | ||||||
| if self._watchdog: | ||||||
| await self._watchdog.stop() | ||||||
| self._watchdog = None | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The watchdog should be stopped before the long awaits. |
||||||
|
|
||||||
| for session_id in list(self._session_servers.keys()): | ||||||
| await self.remove_session(session_id, reason="runtime shutdown") | ||||||
|
|
||||||
| # Stop the shared HTTP server process (streamable-http only) | ||||||
| await self._stop_http_server_process() | ||||||
|
|
@@ -327,46 +370,30 @@ async def _cleanup(self) -> None: | |||||
| try: | ||||||
| await transport._ws.close() | ||||||
| except Exception as e: | ||||||
| logger.error(f"Error closing SignalR WebSocket: {str(e)}") | ||||||
| logger.error(f"Error closing SignalR WebSocket: {e}") | ||||||
|
|
||||||
| # Add a small delay to allow the server to shut down gracefully | ||||||
| if sys.platform == "win32": | ||||||
| await asyncio.sleep(0.5) | ||||||
|
|
||||||
| async def _handle_signalr_session_closed(self, args: list[str]) -> None: | ||||||
| """ | ||||||
| Handle session closed by server. | ||||||
| """ | ||||||
| """Handle session closed by server.""" | ||||||
| if self._cleanup_done: | ||||||
| return | ||||||
|
|
||||||
| if len(args) < 1: | ||||||
| logger.error(f"Received invalid websocket message arguments: {args}") | ||||||
| return | ||||||
|
|
||||||
| session_id = args[0] | ||||||
|
|
||||||
| logger.info(f"Received closed signal for session {session_id}") | ||||||
|
|
||||||
| try: | ||||||
| session_server = self._session_servers.pop(session_id, None) | ||||||
| if session_server: | ||||||
| await session_server.stop() | ||||||
| if session_server.output: | ||||||
| if self.sandboxed: | ||||||
| self._session_output = session_server.output | ||||||
| else: | ||||||
| logger.info( | ||||||
| f"Session {session_id} output: {session_server.output}" | ||||||
| ) | ||||||
| # If this is a sandboxed runtime for a specific session, cancel the execution | ||||||
| if self.sandboxed: | ||||||
| self._cancel_event.set() | ||||||
|
|
||||||
| except Exception as e: | ||||||
| logger.error(f"Error terminating session {session_id}: {str(e)}") | ||||||
| await self.remove_session(session_id, reason="server closed") | ||||||
|
|
||||||
| async def _handle_signalr_message(self, args: list[str]) -> None: | ||||||
| """ | ||||||
| Handle incoming SignalR messages. | ||||||
| """ | ||||||
| """Handle incoming SignalR messages.""" | ||||||
| if self._cleanup_done: | ||||||
| return | ||||||
|
|
||||||
| if len(args) < 2: | ||||||
| logger.error(f"Received invalid websocket message arguments: {args}") | ||||||
| return | ||||||
|
|
@@ -392,7 +419,7 @@ async def _handle_signalr_message(self, args: list[str]) -> None: | |||||
| await session_server.start() | ||||||
| except Exception as e: | ||||||
| logger.error( | ||||||
| f"Error starting session server for session {session_id}: {str(e)}" | ||||||
| f"Error starting session server for session {session_id}: {e}" | ||||||
| ) | ||||||
| await self._on_session_start_error(session_id) | ||||||
| raise | ||||||
|
|
@@ -406,7 +433,7 @@ async def _handle_signalr_message(self, args: list[str]) -> None: | |||||
|
|
||||||
| except Exception as e: | ||||||
| logger.error( | ||||||
| f"Error handling websocket notification for session {session_id}: {str(e)}" | ||||||
| f"Error handling websocket notification for session {session_id}: {e}" | ||||||
| ) | ||||||
|
|
||||||
| async def _handle_signalr_error(self, error: Any) -> None: | ||||||
|
|
@@ -421,17 +448,21 @@ async def _handle_signalr_close(self) -> None: | |||||
| """Handle SignalR connection close event.""" | ||||||
| logger.info("Websocket connection closed.") | ||||||
|
|
||||||
| async def _start_http_server_process(self) -> None: | ||||||
| """Spawn the streamable-http server process. | ||||||
|
|
||||||
| The process is started once and shared across all sessions. | ||||||
| """ | ||||||
| def _get_server_env(self) -> dict[str, str]: | ||||||
| """Return server env vars, with os.environ merged in for Coded servers.""" | ||||||
| env_vars = self._server.env.copy() | ||||||
| if self.server_type is UiPathServerType.Coded: | ||||||
| for name, value in os.environ.items(): | ||||||
| if name not in env_vars: | ||||||
| env_vars[name] = value | ||||||
| return env_vars | ||||||
|
|
||||||
| async def _start_http_server_process(self) -> None: | ||||||
| """Spawn the streamable-http server process. | ||||||
|
|
||||||
| The process is started once and shared across all sessions. | ||||||
| """ | ||||||
| env_vars = self._get_server_env() | ||||||
| merged_env = {**os.environ, **env_vars} if env_vars else None | ||||||
| self._http_server_stderr_lines = [] | ||||||
| self._http_server_process = await asyncio.create_subprocess_exec( | ||||||
|
|
@@ -472,7 +503,12 @@ async def _wait_for_http_server_ready( | |||||
|
|
||||||
| url = self._server.url | ||||||
| if not url: | ||||||
| raise ValueError("streamable-http transport requires url in config") | ||||||
| raise UiPathMcpRuntimeError( | ||||||
| McpErrorCode.CONFIGURATION_ERROR, | ||||||
| "Missing URL for streamable-http server", | ||||||
| "Please specify a 'url' in the server configuration for streamable-http transport.", | ||||||
| UiPathErrorCategory.SYSTEM, | ||||||
| ) | ||||||
|
|
||||||
| for attempt in range(max_retries): | ||||||
| # Check if process has crashed | ||||||
|
|
@@ -561,13 +597,9 @@ async def _monitor_http_server_process(self) -> None: | |||||
| # Stop all HTTP sessions, they will fail on next request anyway | ||||||
| for session_id, session_server in list(self._session_servers.items()): | ||||||
| if isinstance(session_server, StreamableHttpSessionServer): | ||||||
| try: | ||||||
| await session_server.stop() | ||||||
| except Exception as e: | ||||||
| logger.error( | ||||||
| f"Error stopping session {session_id} after process crash: {e}" | ||||||
| ) | ||||||
| self._session_servers.pop(session_id, None) | ||||||
| await self.remove_session( | ||||||
| session_id, reason="http process crash" | ||||||
| ) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
|
|
@@ -577,14 +609,6 @@ async def _register(self) -> None: | |||||
| initialization_successful = False | ||||||
| tools_result: ListToolsResult | None = None | ||||||
| server_stderr_output = "" | ||||||
| env_vars = self._server.env | ||||||
|
|
||||||
| # if server is Coded, include environment variables | ||||||
| if self.server_type is UiPathServerType.Coded: | ||||||
| for name, value in os.environ.items(): | ||||||
| # config env variables should have precedence over system ones | ||||||
| if name not in env_vars: | ||||||
| env_vars[name] = value | ||||||
|
|
||||||
| try: | ||||||
| if self._server.is_streamable_http: | ||||||
|
|
@@ -624,7 +648,7 @@ async def _register(self) -> None: | |||||
| server_params = StdioServerParameters( | ||||||
| command=self._server.command, | ||||||
| args=self._server.args, | ||||||
| env=env_vars, | ||||||
| env=self._get_server_env(), | ||||||
| ) | ||||||
|
|
||||||
| with tempfile.TemporaryFile(mode="w+b") as stderr_temp_binary: | ||||||
|
|
@@ -754,41 +778,39 @@ async def _on_session_start_error(self, session_id: str) -> None: | |||||
| f"Error sending session dispose signal to UiPath MCP Server: {e}" | ||||||
| ) | ||||||
|
|
||||||
| async def _on_keep_alive_response(self, response: CompletionMessage) -> None: | ||||||
| """Handle keep-alive response: log session state, detect orphaned sandboxed runtimes.""" | ||||||
| if response.error: | ||||||
| logger.error(f"Error during keep-alive: {response.error}") | ||||||
| return | ||||||
| session_ids = response.result | ||||||
| logger.info(f"Server active sessions: {session_ids}") | ||||||
| runtime_sessions = {} | ||||||
| for sid, s in self._session_servers.items(): | ||||||
| health = s.get_health_info() | ||||||
| runtime_sessions[sid] = { | ||||||
| "task_done": health.task_done, | ||||||
| "active_requests": len(s._active_requests), | ||||||
|
||||||
| "active_requests": len(s._active_requests), | |
| "active_requests": health.active_request_count, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reaching into the session's private attribute defeats the purpose of SessionHealthInfo. The PR description says active_request_count was removed from SessionHealthInfo, but this is exactly where it would be used. Maybe we should restore it.
Also, I see queue_size from SessionHealthInfo is never consumed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this change needed?