diff --git a/lnbits/app.py b/lnbits/app.py index d0ecb510c..8d9d64fa7 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -21,6 +21,7 @@ from slowapi import Limiter from slowapi.util import get_remote_address from starlette.responses import JSONResponse +from lnbits.cache import cache from lnbits.core.crud import get_installed_extensions from lnbits.core.helpers import migrate_extension_database from lnbits.core.services import websocketUpdater @@ -330,6 +331,8 @@ def register_startup(app: FastAPI): if settings.lnbits_admin_ui: initialize_server_logger() + asyncio.create_task(cache.invalidate_forever()) + except Exception as e: logger.error(str(e)) raise ImportError("Failed to run 'startup' event.") diff --git a/lnbits/cache.py b/lnbits/cache.py new file mode 100644 index 000000000..68ec9808b --- /dev/null +++ b/lnbits/cache.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import asyncio +from time import time +from typing import Any, NamedTuple, Optional + +from loguru import logger + + +class Cached(NamedTuple): + value: Any + expiry: float + + +class Cache: + """ + Small caching utility providing simple get/set interface (very much like redis) + """ + + def __init__(self): + self._values: dict[Any, Cached] = {} + + def get(self, key: str, default=None) -> Optional[Any]: + cached = self._values.get(key) + if cached is not None: + if cached.expiry > time(): + return cached.value + else: + self._values.pop(key) + return default + + def set(self, key: str, value: Any, expiry: float = 10): + self._values[key] = Cached(value, time() + expiry) + + def pop(self, key: str, default=None) -> Optional[Any]: + cached = self._values.pop(key, None) + if cached and cached.expiry > time(): + return cached.value + return default + + async def save_result(self, coro, key: str, expiry: float = 10): + """ + If `key` exists, return its value, otherwise call coro and cache its result + """ + cached = self.get(key) + if cached: + return cached + else: + value = await coro() + self.set(key, value, expiry=expiry) + return value + + async def invalidate_forever(self, interval: float = 10): + while True: + try: + await asyncio.sleep(interval) + ts = time() + expired = [k for k, v in self._values.items() if v.expiry < ts] + for k in expired: + self._values.pop(k) + except Exception: + logger.error("Error invalidating cache") + + +cache = Cache() diff --git a/tests/core/test_cache.py b/tests/core/test_cache.py new file mode 100644 index 000000000..9e3e93527 --- /dev/null +++ b/tests/core/test_cache.py @@ -0,0 +1,56 @@ +import asyncio + +import pytest + +from lnbits.cache import Cache +from tests.conftest import pytest_asyncio + + +@pytest_asyncio.fixture(scope="session") +async def cache(): + cache = Cache() + + task = asyncio.create_task(cache.invalidate_forever(interval=0.1)) + yield cache + task.cancel() + + +key = "foo" +value = "bar" + + +@pytest.mark.asyncio +async def test_cache_get_set(cache): + cache.set(key, value) + assert cache.get(key) == value + assert cache.get(key, default="default") == value + assert cache.get("i-dont-exist", default="default") == "default" + + +@pytest.mark.asyncio +async def test_cache_expiry(cache): + cache.set(key, value, expiry=0.1) + await asyncio.sleep(0.2) + assert not cache.get(key) + + +@pytest.mark.asyncio +async def test_cache_pop(cache): + cache.set(key, value) + assert cache.pop(key) == value + assert not cache.get(key) + assert cache.pop(key, default="a") == "a" + + +@pytest.mark.asyncio +async def test_cache_coro(cache): + called = 0 + + async def test(): + nonlocal called + called += 1 + return called + + await cache.save_result(test, key="test") + result = await cache.save_result(test, key="test") + assert result == called == 1