refac: code intepreter

This commit is contained in:
Timothy Jaeryang Baek 2025-02-10 13:12:05 -08:00
parent 610f9d039a
commit a273cba0fb
2 changed files with 50 additions and 23 deletions

View File

@ -18,6 +18,7 @@ async def execute_code_jupyter(
:param password: Jupyter password (optional)
:param timeout: WebSocket timeout in seconds (default: 10s)
:return: Dictionary with stdout, stderr, and result
- Images are prefixed with "base64:image/png," and separated by newlines if multiple.
"""
session = requests.Session() # Maintain cookies
headers = {} # Headers for requests
@ -28,20 +29,15 @@ async def execute_code_jupyter(
login_url = urljoin(jupyter_url, "/login")
response = session.get(login_url)
response.raise_for_status()
# Retrieve `_xsrf` token
xsrf_token = session.cookies.get("_xsrf")
if not xsrf_token:
raise ValueError("Failed to fetch _xsrf token")
# Send login request
login_data = {"_xsrf": xsrf_token, "password": password}
login_response = session.post(
login_url, data=login_data, cookies=session.cookies
)
login_response.raise_for_status()
# Update headers with `_xsrf`
headers["X-XSRFToken"] = xsrf_token
except Exception as e:
return {
@ -55,18 +51,15 @@ async def execute_code_jupyter(
kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
try:
# Include cookies if authenticating with password
response = session.post(kernel_url, headers=headers, cookies=session.cookies)
response.raise_for_status()
kernel_id = response.json()["id"]
# Construct WebSocket URL
websocket_url = urljoin(
jupyter_url.replace("http", "ws"),
f"/api/kernels/{kernel_id}/channels{params}",
)
# **IMPORTANT:** Include authentication cookies for WebSockets
ws_headers = {}
if password and not token:
ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
@ -75,13 +68,10 @@ async def execute_code_jupyter(
[f"{name}={value}" for name, value in cookies.items()]
)
# Connect to the WebSocket
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
msg_id = str(uuid.uuid4())
# Send execution request
execute_request = {
"header": {
"msg_id": msg_id,
@ -105,37 +95,47 @@ async def execute_code_jupyter(
}
await ws.send(json.dumps(execute_request))
# Collect execution results
stdout, stderr, result = "", "", None
stdout, stderr, result = "", "", []
while True:
try:
message = await asyncio.wait_for(ws.recv(), timeout)
message_data = json.loads(message)
if message_data.get("parent_header", {}).get("msg_id") == msg_id:
msg_type = message_data.get("msg_type")
if msg_type == "stream":
if message_data["content"]["name"] == "stdout":
stdout += message_data["content"]["text"]
elif message_data["content"]["name"] == "stderr":
stderr += message_data["content"]["text"]
elif msg_type in ("execute_result", "display_data"):
result = message_data["content"]["data"].get(
"text/plain", ""
)
data = message_data["content"]["data"]
if "image/png" in data:
result.append(
f"data:image/png;base64,{data['image/png']}"
)
elif "text/plain" in data:
result.append(data["text/plain"])
elif msg_type == "error":
stderr += "\n".join(message_data["content"]["traceback"])
elif (
msg_type == "status"
and message_data["content"]["execution_state"] == "idle"
):
break
except asyncio.TimeoutError:
stderr += "\nExecution timed out."
break
except Exception as e:
return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
finally:
# Shutdown the kernel
if kernel_id:
requests.delete(
f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
@ -144,10 +144,5 @@ async def execute_code_jupyter(
return {
"stdout": stdout.strip(),
"stderr": stderr.strip(),
"result": result.strip() if result else "",
"result": "\n".join(result).strip() if result else "",
}
# Example Usage
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))

View File

@ -1723,6 +1723,38 @@ async def process_chat_response(
)
output["stdout"] = "\n".join(stdoutLines)
result = output.get("result", "")
if result:
resultLines = result.split("\n")
for idx, line in enumerate(resultLines):
if "data:image/png;base64" in line:
id = str(uuid4())
# ensure the path exists
os.makedirs(
os.path.join(CACHE_DIR, "images"),
exist_ok=True,
)
image_path = os.path.join(
CACHE_DIR,
f"images/{id}.png",
)
with open(image_path, "wb") as f:
f.write(
base64.b64decode(
line.split(",")[1]
)
)
resultLines[idx] = (
f"![Output Image {idx}](/cache/images/{id}.png)"
)
output["result"] = "\n".join(resultLines)
except Exception as e:
output = str(e)