diff --git a/lnbits/app.py b/lnbits/app.py index 67d2e181f..8cab32c87 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -24,7 +24,9 @@ from lnbits.core.crud import ( from lnbits.core.helpers import migrate_extension_database from lnbits.core.services.extensions import deactivate_extension, get_valid_extensions from lnbits.core.tasks import ( # watchdog_task + audit_queue, killswitch_task, + wait_for_audit_data, wait_for_paid_invoices, ) from lnbits.exceptions import register_exception_handlers @@ -49,6 +51,7 @@ from .core.db import core_app_extra from .core.models.extensions import Extension, ExtensionMeta, InstallableExtension from .core.services import check_admin_settings, check_webpush_settings from .middleware import ( + AuditMiddleware, CustomGZipMiddleware, ExtensionsRedirectMiddleware, InstalledExtensionMiddleware, @@ -149,6 +152,8 @@ def create_app() -> FastAPI: CustomGZipMiddleware, minimum_size=1000, exclude_paths=["/api/v1/payments/sse"] ) + app.add_middleware(AuditMiddleware, audit_queue=audit_queue) + # required for SSO login app.add_middleware(SessionMiddleware, secret_key=settings.auth_secret_key) @@ -414,6 +419,7 @@ def register_async_tasks(app: FastAPI): if not settings.lnbits_extensions_deactivate_all: create_task(check_and_register_extensions(app)) + create_permanent_task(wait_for_audit_data) create_permanent_task(check_pending_payments) create_permanent_task(invoice_listener) create_permanent_task(internal_invoice_listener) diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index a229ff6bf..5e2b39805 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -19,6 +19,7 @@ from lnbits.settings import get_funding_source, settings from lnbits.tasks import send_push_notification api_invoice_listeners: Dict[str, asyncio.Queue] = {} +audit_queue: asyncio.Queue = asyncio.Queue() async def killswitch_task(): @@ -157,3 +158,12 @@ async def send_payment_push_notification(payment: Payment): f"https://{subscription.host}/wallet?usr={wallet.user}&wal={wallet.id}" ) await send_push_notification(subscription, title, body, url) + + +async def wait_for_audit_data(): + """ + . + """ + while settings.lnbits_running: + data: dict = await audit_queue.get() + print("### data", data) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index ba66dc8b1..4f8ed1150 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -90,6 +90,7 @@ class KeyChecker(SecurityBase): detail="Wallet not found.", ) + request.scope["user_id"] = wallet.user if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value: raise HTTPException( status_code=HTTPStatus.UNAUTHORIZED, @@ -148,6 +149,7 @@ async def check_user_exists( if not account: raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") + r.scope["user_id"] = account.id if not settings.is_user_allowed(account.id): raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not allowed.") diff --git a/lnbits/middleware.py b/lnbits/middleware.py index 611995a06..7a643f371 100644 --- a/lnbits/middleware.py +++ b/lnbits/middleware.py @@ -1,11 +1,15 @@ +import asyncio +from datetime import datetime, timezone from http import HTTPStatus -from typing import Any, List, Union +from typing import Any, List, Optional, Union -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Response from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse +from loguru import logger from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.middleware import SlowAPIMiddleware +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.gzip import GZipMiddleware from starlette.types import ASGIApp, Receive, Scope, Send @@ -120,6 +124,51 @@ class ExtensionsRedirectMiddleware: await self.app(scope, receive, send) +class AuditMiddleware(BaseHTTPMiddleware): + + def __init__(self, app: ASGIApp, audit_queue: asyncio.Queue) -> None: + super().__init__(app) + self.audit_queue = audit_queue + # delete_time purge after X days + # time, # include pats, exclude paths (regex) + + async def dispatch(self, request: Request, call_next) -> Response: + start_time = datetime.now(timezone.utc) + response: Optional[Response] = None + try: + response = await call_next(request) + assert response + return response + finally: + duration = (datetime.now(timezone.utc) - start_time).total_seconds() + await self._log_audit(request, response, duration) + + async def _log_audit( + self, request: Request, response: Optional[Response], duration: float + ): + try: + http_method = request.scope.get("method", None) + path = request.scope.get("path", None) + response_code = str(response.status_code) if response else None + if not settings.is_http_request_auditable(http_method, path, response_code): + print("### NOT", http_method, path, response_code) + return None + data = { + "ip": request.client.host if request.client else None, + "user_id": request.scope.get("user_id", None), + "path": path, + "route_path": getattr(request.scope.get("route", {}), "path", None), + "request_type": request.scope.get("type", None), + "request_method": http_method, + "query_string": request.scope.get("query_string", None), + "response_code": response_code, + "duration": duration, + } + await self.audit_queue.put(data) + except Exception as ex: + logger.warning(ex) + + def add_ratelimit_middleware(app: FastAPI): core_app_extra.register_new_ratelimiter() # latest https://slowapi.readthedocs.io/en/latest/ diff --git a/lnbits/settings.py b/lnbits/settings.py index d951505d7..080fd2a2f 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -4,6 +4,7 @@ import importlib import importlib.metadata import inspect import json +import re from enum import Enum from hashlib import sha256 from os import path @@ -26,6 +27,7 @@ def list_parse_fallback(v: str): return [] + class LNbitsSettings(BaseModel): @classmethod def validate_list(cls, val): @@ -509,6 +511,70 @@ class KeycloakAuthSettings(LNbitsSettings): keycloak_client_secret: str = Field(default="") +class AuditSettings(LNbitsSettings): + lnbits_audit_enabled: bool = Field(default=True) + + # If true the client IP address will be loged + lnbits_audit_log_ip: bool = Field(default=False) + + # List of paths to be included (regex match). Empty list means all. + lnbits_audit_include_paths: list[str] = Field(default=[".*api/v1/.*"]) + # List of paths to be excluded (regex match). Empty list means none. + lnbits_audit_exclude_paths: list[str] = Field( + default=["/static", "service-worker.js"] + ) + + # List of HTTP methods to be included. Empty lists means all. + # GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS + lnbits_audit_http_methods: list[str] = Field(default=[]) + + # List of HTTP codes to be included (regex match). Empty lists means all. + lnbits_audit_http_response_codes: list[str] = Field(default=[]) + + def is_http_request_auditable( + self, + http_method: Optional[str], + path: Optional[str], + http_response_code: Optional[str], + ) -> bool: + if not self.lnbits_audit_enabled: + return False + if len(self.lnbits_audit_http_methods) != 0: + if not http_method or http_method not in self.lnbits_audit_http_methods: + return False + + if not self._is_http_request_path_auditable(path): + return False + + if len(self.lnbits_audit_http_response_codes) != 0: + is_response_code_included = True + if not http_response_code: + return False + for response_code in self.lnbits_audit_http_response_codes: + if _re_fullmatch_safe(response_code, http_response_code): + is_response_code_included = True + break + if not is_response_code_included: + return False + + return True + + def _is_http_request_path_auditable(self, path: Optional[str]): + if len(self.lnbits_audit_exclude_paths) != 0 and path: + for exclude_path in self.lnbits_audit_exclude_paths: + if _re_fullmatch_safe(exclude_path, path): + return False + + if len(self.lnbits_audit_include_paths) != 0: + if not path: + return False + for include_path in self.lnbits_audit_include_paths: + if _re_fullmatch_safe(include_path, path): + return True + + return False + + class EditableSettings( UsersSettings, ExtensionsSettings, @@ -520,6 +586,7 @@ class EditableSettings( LightningSettings, WebPushSettings, NodeUISettings, + AuditSettings, AuthSettings, NostrAuthSettings, GoogleAuthSettings, @@ -701,6 +768,14 @@ class SettingsField(BaseModel): value: Optional[Any] tag: str = "core" +def _re_fullmatch_safe(pattern: str, string: str): + try: + return re.fullmatch(pattern, string) is not None + except Exception as _: + logger.warning(f"Regex error for pattern {pattern}") + return False + + def set_cli_settings(**kwargs): for key, value in kwargs.items():