diff --git a/lnbits/extensions/copilot/views.py b/lnbits/extensions/copilot/views.py index 87ce8a342..46a82ca6d 100644 --- a/lnbits/extensions/copilot/views.py +++ b/lnbits/extensions/copilot/views.py @@ -1,26 +1,22 @@ from quart import g, abort, render_template, jsonify, websocket from http import HTTPStatus import httpx - +from collections import defaultdict from lnbits.decorators import check_user_exists, validate_uuids - from . import copilot_ext from .crud import get_copilot - from quart import g, abort, render_template, jsonify, websocket from functools import wraps import trio import shortuuid from . import copilot_ext - @copilot_ext.route("/") @validate_uuids(["usr"], required=True) @check_user_exists() async def index(): return await render_template("copilot/index.html", user=g.user) - @copilot_ext.route("/cp/") async def compose(): return await render_template("copilot/compose.html") @@ -36,58 +32,27 @@ async def panel(): # socket_relay is a list where the control panel or # lnurl endpoints can leave a message for the compose window -socket_relay = {} +connected_websockets = defaultdict(set) +@copilot_ext.websocket("/ws//") +async def wss(id): + copilot = await get_copilot(id) + if not copilot: + return "", HTTPStatus.FORBIDDEN + global connected_websockets + send_channel, receive_channel = trio.open_memory_channel(0) + connected_websockets[id].add(send_channel) + try: + while True: + data = await receive_channel.receive() + await websocket.send(data) + finally: + connected_websockets[id].remove(send_channel) -@copilot_ext.websocket("/ws/panel/") -async def ws_panel(copilot_id): - global socket_relay - while True: - data = await websocket.receive() - socket_relay[copilot_id] = shortuuid.uuid()[:5] + "-" + data + "-" + "none" - - -@copilot_ext.websocket("/ws/compose/") -async def ws_compose(copilot_id): - global socket_relay - while True: - data = await websocket.receive() - await websocket.send(socket_relay[copilot_id]) - - -async def updater(data, comment, copilot): - global socket_relay - socket_relay[copilot] = shortuuid.uuid()[:5] + "-" + str(data) + "-" + str(comment) - - - - -##################WEBSOCKET ROUTES######################## - -# socket_relay is a list where the control panel or -# lnurl endpoints can leave a message for the compose window - -connected_websockets = set() - - -def collect_websocket(func): - @wraps(func) - async def wrapper(*args, **kwargs): - global connected_websockets - send_channel, receive_channel = trio.open_memory_channel(0) - connected_websockets.add(send_channel) - try: - return await func(receive_channel, *args, **kwargs) - finally: - connected_websockets.remove(send_channel) - - return wrapper - - -@copilot_ext.websocket("/ws") -@collect_websocket -async def wss(receive_channel): - - while True: - data = await receive_channel.receive() - await websocket.send(data) \ No newline at end of file +async def updater(copilot_id, data, comment): + copilot = await get_copilot(copilot_id) + if not copilot: + return + print(connected_websockets) + for queue in connected_websockets[copilot_id]: + await queue.send(f"{data + '-' + comment}") \ No newline at end of file diff --git a/lnbits/extensions/copilot/views_api.py b/lnbits/extensions/copilot/views_api.py index fab571b12..76b2d54cd 100644 --- a/lnbits/extensions/copilot/views_api.py +++ b/lnbits/extensions/copilot/views_api.py @@ -100,5 +100,8 @@ async def api_copilot_ws_relay(copilot_id, comment, data): if not copilot: return jsonify({"message": "copilot does not exist"}), HTTPStatus.NOT_FOUND - await updater(data, comment, copilot_id) + try: + await updater(copilot_id, data, comment) + except: + return "", HTTPStatus.FORBIDDEN return "", HTTPStatus.OK