Skip to content

Commit

Permalink
refactor: Enhance session management in mcp_tools.py by implementing …
Browse files Browse the repository at this point in the history
…async context manager for better resource handling and error management
  • Loading branch information
onuratakan committed Jan 6, 2025
1 parent bd95e19 commit 759acbf
Showing 1 changed file with 54 additions and 39 deletions.
93 changes: 54 additions & 39 deletions src/upsonicai/server/tools/server/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.stdio import get_default_environment
import asyncio
from contextlib import asynccontextmanager
# Create server parameters for stdio connection

from .api import app, timeout
Expand All @@ -26,65 +28,78 @@ class ListToolsRequest(BaseRequestMCP):
pass


# Global variable to store the session
session_store = None

async def get_session(command: str, args: list, env: dict):
@asynccontextmanager
async def managed_session(command: str, args: list, env: dict | None = None):
print("env", env)
print("args", args)
print("command", command)

if env is None:
env = get_default_environment()
else:
default_env = get_default_environment()
default_env.update(env)
env = default_env

server_params = StdioServerParameters(
command=command, # Executable
args=args, # Optional command line arguments
env=env, # Environment variables
command=command,
args=args,
env=env,
)

client = None
session = None

try:
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
print("Initializing session...")
await session.initialize()
print("Session initialized.")
yield session
except GeneratorExit:
print("GeneratorExit occurred, session closed.")
except Exception as e:
print(f"Error in session: {e}")
raise
client = stdio_client(server_params)
read, write = await client.__aenter__()
session = ClientSession(read, write)
await session.__aenter__()
await session.initialize()
yield session
finally:
print("Session cleanup.")
if session:
try:
await session.__aexit__(None, None, None)
except Exception:
pass
if client:
try:
await client.__aexit__(None, None, None)
except Exception:
pass


@app.post(f"{prefix}/tools")
@timeout(30.0)
async def list_tools(request: ListToolsRequest):
print("Listing tools...")

async for session in get_session(request.command, request.args, request.env):
try:
try:
async with managed_session(request.command, request.args, request.env) as session:
tools = await session.list_tools()
print(f"Tools listed: {tools}")
except Exception as e:
print(f"Error listing tools: {e}")
raise HTTPException(status_code=500, detail="Failed to list tools")
return {"available_tools": tools}
return {"available_tools": tools}
except asyncio.CancelledError:
raise HTTPException(status_code=408, detail="Request timeout")
except Exception as e:
print(f"Error listing tools: {e}")
raise HTTPException(status_code=500, detail="Failed to list tools")


@app.post(f"{prefix}/call_tool")
@timeout(30.0)
async def call_tool(request: ToolRequest):
print(f"Received tool call request: {request}")

async for session in get_session(request.command, request.args, request.env):
try:
print(
f"Calling tool: {request.tool_name} with arguments: {request.arguments}"
)
result = await session.call_tool(
request.tool_name, arguments=request.arguments
)

try:
async with managed_session(request.command, request.args, request.env) as session:
print(f"Calling tool: {request.tool_name} with arguments: {request.arguments}")
result = await session.call_tool(request.tool_name, arguments=request.arguments)
print(f"Tool call result: {result}")
except Exception as e:
print(f"Error calling tool: {e}")
raise HTTPException(status_code=500, detail="Failed to call tool")
return {"result": result}
return {"result": result}
except asyncio.CancelledError:
raise HTTPException(status_code=408, detail="Request timeout")
except Exception as e:
print(f"Error calling tool: {e}")
raise HTTPException(status_code=500, detail="Failed to call tool")

0 comments on commit 759acbf

Please sign in to comment.