fix(jupyter): fix kernel_id not set and optimize code

This commit is contained in:
orenzhang 2025-03-03 22:05:50 +08:00
parent 3d6e48b05e
commit 744ffbb1fb
No known key found for this signature in database
GPG Key ID: 73D45F78147E506C

View File

@ -1,14 +1,15 @@
import asyncio
import json
import uuid
from typing import Optional
import httpx
import websockets
import requests
from urllib.parse import urljoin
async def execute_code_jupyter(
jupyter_url, code, token=None, password=None, timeout=10
):
jupyter_url: str, code: str, token: str = None, password: str = None, timeout: int = 60
) -> Optional[dict]:
"""
Executes Python code in a Jupyter kernel.
Supports authentication with a token or password.
@ -20,80 +21,70 @@ async def execute_code_jupyter(
:return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
"""
session = requests.Session() # Maintain cookies
headers = {} # Headers for requests
# Authenticate using password
jupyter_url = jupyter_url.rstrip("/")
client = httpx.AsyncClient(base_url=jupyter_url, timeout=timeout, follow_redirects=True)
headers = {}
# password authentication
if password and not token:
try:
login_url = urljoin(jupyter_url, "/login")
response = session.get(login_url)
response = await client.get("/login")
response.raise_for_status()
xsrf_token = session.cookies.get("_xsrf")
xsrf_token = response.cookies.get("_xsrf")
if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token")
login_data = {"_xsrf": xsrf_token, "password": password}
login_response = session.post(
login_url, data=login_data, cookies=session.cookies
)
login_response.raise_for_status()
raise ValueError("_xsrf token not found")
response = await client.post("/login", data={"_xsrf": xsrf_token, "password": password})
response.raise_for_status()
headers["X-XSRFToken"] = xsrf_token
except Exception as e:
return {
"stdout": "",
"stderr": f"Authentication Error: {str(e)}",
"result": "",
}
return {"stdout": "", "stderr": f"Authentication Error: {str(e)}", "result": ""}
# Construct API URLs with authentication token if provided
params = f"?token={token}" if token else ""
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
# token authentication
params = {"token": token} if token else {}
kernel_id = ""
try:
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
response = await client.post(url="/api/kernels", params=params, headers=headers)
response.raise_for_status()
kernel_id = response.json()["id"]
websocket_url = urljoin(
jupyter_url.replace("http", "ws"),
f"/api/kernels/{kernel_id}/channels{params}",
)
ws_base = jupyter_url.replace("http", "ws")
websocket_url = f"{ws_base}/api/kernels/{kernel_id}/channels" + (f"?token={token}" if token else "")
ws_headers = {}
if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
cookies = {name: value for name, value in session.cookies.items()}
ws_headers["Cookie"] = "; ".join(
[f"{name}={value}" for name, value in cookies.items()]
)
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
msg_id = str(uuid.uuid4())
execute_request = {
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
ws_headers = {
"X-XSRFToken": client.cookies.get("_xsrf"),
"Cookie": "; ".join([f"{name}={value}" for name, value in client.cookies.items()]),
}
await ws.send(json.dumps(execute_request))
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
msg_id = str(uuid.uuid4())
await ws.send(
json.dumps(
{
"header": {
"msg_id": msg_id,
"msg_type": "execute_request",
"username": "user",
"session": str(uuid.uuid4()),
"date": "",
"version": "5.3",
},
"parent_header": {},
"metadata": {},
"content": {
"code": code,
"silent": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
"stop_on_error": True,
},
"channel": "shell",
}
)
)
stdout, stderr, result = "", "", []
@ -101,32 +92,27 @@ async def execute_code_jupyter(
try:
message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
msg_type = message_data.get("msg_type")
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
continue
if msg_type == "stream":
msg_type = message_data.get("msg_type")
match msg_type:
case "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
elif msg_type in ("execute_result", "display_data"):
case "execute_result" | "display_data":
data = message_data["content"]["data"]
if "image/png" in data:
result.append(
f"data:image/png;base64,{data['image/png']}"
)
result.append(f"data:image/png;base64,{data['image/png']}")
elif "text/plain" in data:
result.append(data["text/plain"])
elif msg_type == "error":
case "error":
stderr += "\n".join(message_data["content"]["traceback"])
elif (
msg_type == "status"
and message_data["content"]["execution_state"] == "idle"
):
break
case "status":
if message_data["content"]["execution_state"] == "idle":
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
@ -137,12 +123,7 @@ async def execute_code_jupyter(
finally:
if kernel_id:
requests.delete(
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
)
await client.delete(f"/api/kernels/{kernel_id}", headers=headers, params=params)
await client.aclose()
return {
"stdout": stdout.strip(),
"stderr": stderr.strip(),
"result": "\n".join(result).strip() if result else "",
}
return {"stdout": stdout.strip(), "stderr": stderr.strip(), "result": "\n".join(result).strip() if result else ""}