diff --git a/lnbits/utils/gateway.py b/lnbits/utils/gateway.py index b5045fb46..a9a3a3e1b 100644 --- a/lnbits/utils/gateway.py +++ b/lnbits/utils/gateway.py @@ -3,7 +3,7 @@ import uuid from asyncio import Queue, TimeoutError from http import HTTPStatus from json import dumps, loads -from typing import Any, AsyncIterator, Awaitable, Dict, Mapping, Optional +from typing import Any, AsyncIterator, Awaitable, Dict, Iterator, Mapping, Optional from urllib.parse import urlencode import httpx @@ -13,7 +13,7 @@ from loguru import logger from websocket import WebSocketApp from lnbits.settings import settings - +from contextlib import asynccontextmanager class HTTPTunnelClient: @@ -215,14 +215,24 @@ class HTTPTunnelClient: timeout=timeout, ) - async def aiter_text(self) -> AsyncIterator[str]: - for chunk in await self._chunks.get(): - print("### chunk", chunk) - yield chunk - async def aclose(self) -> None: self.disconnect() + @asynccontextmanager + async def stream( + self, + method: str, + url: str, + + ) -> AsyncIterator["HTTPTunnelResponse"]: + + response = HTTPTunnelResponse(queue=self._chunks) + try: + yield response + finally: + await response.aclose() + + async def _handle_response(self, resp: Optional[dict]): if not resp: return @@ -247,8 +257,10 @@ class HTTPTunnelClient: class HTTPTunnelResponse: # status code, detail - def __init__(self, resp: Optional[dict]): + def __init__(self, resp: Optional[dict] = None, queue: Optional[Queue] = None): self._resp = resp + self._queue = queue + self._running = True @property def is_error(self) -> bool: @@ -283,7 +295,17 @@ class HTTPTunnelResponse: body = self.text return loads(body, **kwargs) if body else None + async def aiter_text( + self + ) -> AsyncIterator[str]: + if not self._queue: + return + while self._running: + data = await self._queue.get() + yield data + async def aclose(self) -> None: + self._running = False class HTTPInternalCall: def __init__(self, routers: APIRouter, x_api_key: str): diff --git a/lnbits/wallets/lnbits.py b/lnbits/wallets/lnbits.py index 9a8a7fcb5..7c4c00057 100644 --- a/lnbits/wallets/lnbits.py +++ b/lnbits/wallets/lnbits.py @@ -205,32 +205,14 @@ class LNbitsWallet(Wallet): while settings.lnbits_running: try: - async with httpx.AsyncClient( - timeout=None, headers=self.headers - ) as client: - del client.headers[ - "accept-encoding" - ] # we have to disable compression for SSEs - async with client.stream( - "GET", url, content="text/event-stream" + async with self.client.stream( + "GET", url ) as r: - sse_trigger = False - async for line in r.aiter_lines(): - # The data we want to listen to is of this shape: - # event: payment-received - # data: {.., "payment_hash" : "asd"} - if line.startswith("event: payment-received"): - sse_trigger = True - continue - elif sse_trigger and line.startswith("data:"): - data = json.loads(line[len("data:") :]) - sse_trigger = False - yield data["payment_hash"] - else: - sse_trigger = False + async for data in r.aiter_text(): + yield data["payment_hash"] - except (OSError, httpx.ReadError, httpx.ConnectError, httpx.ReadTimeout): - pass + except Exception as exc: + logger.warning(exc) logger.error( "lost connection to lnbits /payments/sse, retrying in 5 seconds"