refactor(jupyter): replace httpx with aiohttp

This commit is contained in:
orenzhang 2025-03-04 11:55:01 +08:00
parent 744ffbb1fb
commit 96e0c4d7b8
No known key found for this signature in database
GPG Key ID: 73D45F78147E506C

View File

@@ -1,129 +1,185 @@
import asyncio
import json
import logging
import uuid
from typing import Optional
from urllib.parse import urlencode

import aiohttp
import httpx
import websockets
from pydantic import BaseModel
from websockets import ClientConnection
logger = logging.getLogger(__name__)
class ResultModel(BaseModel):
    """
    Outcome of one code execution.

    Every field defaults to an empty string:
    - ``stdout``: text collected for the standard output stream
    - ``stderr``: text collected for the standard error stream / error reports
    - ``result``: rendered execution results
    """

    stdout: str | None = ""
    stderr: str | None = ""
    result: str | None = ""
class JupyterCodeExecuter:
"""
Execute code in jupyter notebook
"""
def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
"""
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
:param token: Jupyter authentication token (optional)
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 60s)
"""
self.base_url = base_url.rstrip("/")
self.code = code
self.token = token
self.password = password
self.timeout = timeout
self.kernel_id = ""
self.session = aiohttp.ClientSession(base_url=self.base_url)
self.params = {}
self.result = ResultModel()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.kernel_id:
try:
await self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params)
except Exception as err:
logger.exception("close kernel failed, %s", err)
await self.session.close()
async def run(self) -> ResultModel:
try:
await self.sign_in()
await self.init_kernel()
await self.execute_code()
except Exception as err:
logger.error(err)
self.result.stderr = f"Error: {err}"
return self.result
async def sign_in(self) -> None:
# password authentication
if self.password and not self.token:
async with self.session.get("/login") as response:
response.raise_for_status()
xsrf_token = response.cookies["_xsrf"].value
if not xsrf_token:
raise ValueError("_xsrf token not found")
self.session.cookie_jar.update_cookies(response.cookies)
self.session.headers.update({"X-XSRFToken": xsrf_token})
async with self.session.post(
"/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
) as response:
response.raise_for_status()
self.session.cookie_jar.update_cookies(response.cookies)
# token authentication
if self.token:
self.params.update({"token": self.token})
async def init_kernel(self) -> None:
async with self.session.post(url="/api/kernels", params=self.params) as response:
response.raise_for_status()
kernel_data = await response.json()
self.kernel_id = kernel_data["id"]
def init_ws(self) -> (str, dict):
ws_base = self.base_url.replace("http", "ws")
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
ws_headers = {}
if self.password and not self.token:
ws_headers = {
"Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
**self.session.headers,
}
return websocket_url, ws_headers
async def execute_code(self) -> None:
# initialize ws
websocket_url, ws_headers = self.init_ws()
# execute
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
await self.execute_in_jupyter(ws)
async def execute_in_jupyter(self, ws: ClientConnection) -> None:
# send message
msg_id = uuid.uuid4().hex
await ws.send(
json.dumps(
{
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": uuid.uuid4().hex,
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": self.code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
)
)
# parse message
stdout, stderr, result = "", "", []
while True:
try:
# wait for message
message = await asyncio.wait_for(ws.recv(), self.timeout)
message_data = json.loads(message)
# msg id not match, skip
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
continue
# check message type
msg_type = message_data.get("msg_type")
match msg_type:
case "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
case "execute_result" | "display_data":
data = message_data["content"]["data"]
if "image/png" in data:
result.append(f"data:image/png;base64,{data['image/png']}")
elif "text/plain" in data:
result.append(data["text/plain"])
case "error":
stderr += "\n".join(message_data["content"]["traceback"])
case "status":
if message_data["content"]["execution_state"] == "idle":
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
self.result.stdout = stdout.strip()
self.result.stderr = stderr.strip()
self.result.result = "\n".join(result).strip() if result else ""
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Executes code in a Jupyter kernel.
    Supports authentication with a token or password.
    :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: WebSocket timeout in seconds (default: 60s)
    :return: Dictionary with stdout, stderr, and result
    - Images are prefixed with "data:image/png;base64," and separated by newlines if multiple.
    """
    # The executor deletes its kernel and closes its HTTP session when the context exits.
    async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
        result = await executor.run()
        return result.model_dump()