This commit is contained in:
Timothy Jaeryang Baek 2025-03-03 20:16:40 -08:00
parent bb2bd7d721
commit e628bfe6ff

View File

@ -7,7 +7,6 @@ from typing import Optional
import aiohttp
import websockets
from pydantic import BaseModel
from websockets import ClientConnection
from open_webui.env import SRC_LOG_LEVELS
@ -30,7 +29,14 @@ class JupyterCodeExecuter:
Execute code in jupyter notebook
"""
def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
def __init__(
self,
base_url: str,
code: str,
token: str = "",
password: str = "",
timeout: int = 60,
):
"""
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
:param code: Code to execute
@ -54,7 +60,9 @@ class JupyterCodeExecuter:
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.kernel_id:
try:
async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
async with self.session.delete(
f"/api/kernels/{self.kernel_id}", params=self.params
) as response:
response.raise_for_status()
except Exception as err:
logger.exception("close kernel failed, %s", err)
@ -81,7 +89,9 @@ class JupyterCodeExecuter:
self.session.cookie_jar.update_cookies(response.cookies)
self.session.headers.update({"X-XSRFToken": xsrf_token})
async with self.session.post(
"/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
"/login",
data={"_xsrf": xsrf_token, "password": self.password},
allow_redirects=False,
) as response:
response.raise_for_status()
self.session.cookie_jar.update_cookies(response.cookies)
@ -91,7 +101,9 @@ class JupyterCodeExecuter:
self.params.update({"token": self.token})
async def init_kernel(self) -> None:
async with self.session.post(url="/api/kernels", params=self.params) as response:
async with self.session.post(
url="/api/kernels", params=self.params
) as response:
response.raise_for_status()
kernel_data = await response.json()
self.kernel_id = kernel_data["id"]
@ -103,7 +115,12 @@ class JupyterCodeExecuter:
ws_headers = {}
if self.password and not self.token:
ws_headers = {
"Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
"Cookie": "; ".join(
[
f"{cookie.key}={cookie.value}"
for cookie in self.session.cookie_jar
]
),
**self.session.headers,
}
return websocket_url, ws_headers
@ -112,10 +129,12 @@ class JupyterCodeExecuter:
# initialize ws
websocket_url, ws_headers = self.init_ws()
# execute
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
async with websockets.connect(
websocket_url, additional_headers=ws_headers
) as ws:
await self.execute_in_jupyter(ws)
async def execute_in_jupyter(self, ws: ClientConnection) -> None:
async def execute_in_jupyter(self, ws) -> None:
# send message
msg_id = uuid.uuid4().hex
await ws.send(
@ -184,6 +203,8 @@ class JupyterCodeExecuter:
async def execute_code_jupyter(
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
async with JupyterCodeExecuter(
base_url, code, token, password, timeout
) as executor:
result = await executor.run()
return result.model_dump()