add cache utility (#1790)

* add simple caching utility

* test cache

* remove prefix, default on get

* check expiry in pop aswell

* remove unnecessary type

* improve invalidation task

increase default interval to 10 seconds - doesnt have to check that often.
instead of recreating the dict everytime mutate the existing one
This commit is contained in:
jackstar12 2023-08-02 14:13:31 +02:00 committed by GitHub
parent 9bc8a9db55
commit 2577ce7f81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 124 additions and 0 deletions

View File

@ -21,6 +21,7 @@ from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.responses import JSONResponse
from lnbits.cache import cache
from lnbits.core.crud import get_installed_extensions
from lnbits.core.helpers import migrate_extension_database
from lnbits.core.services import websocketUpdater
@ -330,6 +331,8 @@ def register_startup(app: FastAPI):
if settings.lnbits_admin_ui:
initialize_server_logger()
asyncio.create_task(cache.invalidate_forever())
except Exception as e:
logger.error(str(e))
raise ImportError("Failed to run 'startup' event.")

65
lnbits/cache.py Normal file
View File

@ -0,0 +1,65 @@
from __future__ import annotations
import asyncio
from time import time
from typing import Any, NamedTuple, Optional
from loguru import logger
class Cached(NamedTuple):
value: Any
expiry: float
class Cache:
"""
Small caching utility providing simple get/set interface (very much like redis)
"""
def __init__(self):
self._values: dict[Any, Cached] = {}
def get(self, key: str, default=None) -> Optional[Any]:
cached = self._values.get(key)
if cached is not None:
if cached.expiry > time():
return cached.value
else:
self._values.pop(key)
return default
def set(self, key: str, value: Any, expiry: float = 10):
self._values[key] = Cached(value, time() + expiry)
def pop(self, key: str, default=None) -> Optional[Any]:
cached = self._values.pop(key, None)
if cached and cached.expiry > time():
return cached.value
return default
async def save_result(self, coro, key: str, expiry: float = 10):
"""
If `key` exists, return its value, otherwise call coro and cache its result
"""
cached = self.get(key)
if cached:
return cached
else:
value = await coro()
self.set(key, value, expiry=expiry)
return value
async def invalidate_forever(self, interval: float = 10):
while True:
try:
await asyncio.sleep(interval)
ts = time()
expired = [k for k, v in self._values.items() if v.expiry < ts]
for k in expired:
self._values.pop(k)
except Exception:
logger.error("Error invalidating cache")
cache = Cache()

56
tests/core/test_cache.py Normal file
View File

@ -0,0 +1,56 @@
import asyncio
import pytest
from lnbits.cache import Cache
from tests.conftest import pytest_asyncio
@pytest_asyncio.fixture(scope="session")
async def cache():
cache = Cache()
task = asyncio.create_task(cache.invalidate_forever(interval=0.1))
yield cache
task.cancel()
key = "foo"
value = "bar"
@pytest.mark.asyncio
async def test_cache_get_set(cache):
cache.set(key, value)
assert cache.get(key) == value
assert cache.get(key, default="default") == value
assert cache.get("i-dont-exist", default="default") == "default"
@pytest.mark.asyncio
async def test_cache_expiry(cache):
cache.set(key, value, expiry=0.1)
await asyncio.sleep(0.2)
assert not cache.get(key)
@pytest.mark.asyncio
async def test_cache_pop(cache):
cache.set(key, value)
assert cache.pop(key) == value
assert not cache.get(key)
assert cache.pop(key, default="a") == "a"
@pytest.mark.asyncio
async def test_cache_coro(cache):
called = 0
async def test():
nonlocal called
called += 1
return called
await cache.save_result(test, key="test")
result = await cache.save_result(test, key="test")
assert result == called == 1