diff --git a/.gitignore b/.gitignore index 9e661381a..d1d668b45 100644 --- a/.gitignore +++ b/.gitignore @@ -49,8 +49,8 @@ fly.toml lnbits-backup.zip # Ignore extensions (post installable extension PR) -extensions -upgrades/ +/extensions +/upgrades/ # builded python package dist diff --git a/lnbits/app.py b/lnbits/app.py index 3854c173c..c649c77c1 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -17,10 +17,13 @@ from slowapi.util import get_remote_address from starlette.middleware.sessions import SessionMiddleware from lnbits.core.crud import ( + add_installed_extension, get_dbversions, get_installed_extensions, update_installed_extension_state, ) +from lnbits.core.extensions.extension_manager import deactivate_extension +from lnbits.core.extensions.helpers import version_parse from lnbits.core.helpers import migrate_extension_database from lnbits.core.tasks import ( # watchdog_task killswitch_task, @@ -44,14 +47,8 @@ from lnbits.wallets import get_funding_source, set_funding_source from .commands import migrate_databases from .core import init_core_routers from .core.db import core_app_extra +from .core.extensions.models import Extension, InstallableExtension from .core.services import check_admin_settings, check_webpush_settings -from .core.views.extension_api import add_installed_extension -from .extension_manager import ( - Extension, - InstallableExtension, - get_valid_extensions, - version_parse, -) from .middleware import ( CustomGZipMiddleware, ExtensionsRedirectMiddleware, @@ -243,6 +240,7 @@ async def check_installed_extensions(app: FastAPI): ) except Exception as e: logger.warning(e) + await deactivate_extension(ext.id) logger.warning( f"Failed to re-install extension: {ext.id} ({ext.installed_version})" ) @@ -317,7 +315,6 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension): # mount routes for the new version core_app_extra.register_new_ext_routes(extension) - ext.notify_upgrade(extension.upgrade_hash) def register_custom_extensions_path(): @@ -380,24 +377,22 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None: ) app.mount(s["path"], StaticFiles(directory=static_dir), s["name"]) - if hasattr(ext_module, f"{ext.code}_redirect_paths"): - ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths") - settings.lnbits_extensions_redirects = [ - r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code - ] - for r in ext_redirects: - r["ext_id"] = ext.code - settings.lnbits_extensions_redirects.append(r) + ext_redirects = ( + getattr(ext_module, f"{ext.code}_redirect_paths") + if hasattr(ext_module, f"{ext.code}_redirect_paths") + else [] + ) - logger.trace(f"adding route for extension {ext_module}") + settings.activate_extension_paths(ext.code, ext.upgrade_hash, ext_redirects) + logger.trace(f"Adding route for extension {ext_module}.") prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else "" app.include_router(router=ext_route, prefix=prefix) async def check_and_register_extensions(app: FastAPI): await check_installed_extensions(app) - for ext in get_valid_extensions(False): + for ext in Extension.get_valid_extensions(False): try: register_ext_routes(app, ext) except Exception as exc: diff --git a/lnbits/commands.py b/lnbits/commands.py index bb4928639..bf9a1b936 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -25,18 +25,18 @@ from lnbits.core.crud import ( remove_deleted_wallets, update_payment_status, ) +from lnbits.core.extensions.models import ( + CreateExtension, + ExtensionRelease, + InstallableExtension, +) from lnbits.core.helpers import migrate_databases -from lnbits.core.models import Payment, PaymentState, User +from lnbits.core.models import Payment, PaymentState from lnbits.core.services import check_admin_settings from lnbits.core.views.extension_api import ( api_install_extension, api_uninstall_extension, ) -from lnbits.extension_manager import ( - CreateExtension, - ExtensionRelease, - InstallableExtension, -) from lnbits.settings import settings from lnbits.wallets.base import Wallet @@ -611,7 +611,7 @@ async def _call_install_extension( ) resp.raise_for_status() else: - await api_install_extension(data, User(id="mock_id")) + await api_install_extension(data) async def _call_uninstall_extension( @@ -625,7 +625,7 @@ async def _call_uninstall_extension( ) resp.raise_for_status() else: - await api_uninstall_extension(extension, User(id="mock_id")) + await api_uninstall_extension(extension) async def _can_run_operation(url) -> bool: diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 39fbad93b..3de0345b4 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -8,14 +8,14 @@ import shortuuid from passlib.context import CryptContext from lnbits.core.db import db -from lnbits.core.models import PaymentState -from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page -from lnbits.extension_manager import ( +from lnbits.core.extensions.models import ( InstallableExtension, PayToEnableInfo, UserExtension, UserExtensionInfo, ) +from lnbits.core.models import PaymentState +from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page from lnbits.settings import ( AdminSettings, EditableSettings, @@ -430,7 +430,7 @@ async def get_installed_extension( async def get_installed_extensions( active: Optional[bool] = None, conn: Optional[Connection] = None, -) -> List["InstallableExtension"]: +) -> List[InstallableExtension]: rows = await (conn or db).fetchall( "SELECT * FROM installed_extensions", (), diff --git a/lnbits/core/extensions/extension_manager.py b/lnbits/core/extensions/extension_manager.py new file mode 100644 index 000000000..168104175 --- /dev/null +++ b/lnbits/core/extensions/extension_manager.py @@ -0,0 +1,93 @@ +import asyncio +import importlib + +from loguru import logger + +from lnbits.core.crud import ( + add_installed_extension, + delete_installed_extension, + get_dbversions, + get_installed_extension, + update_installed_extension_state, +) +from lnbits.core.db import core_app_extra +from lnbits.core.helpers import migrate_extension_database +from lnbits.settings import settings + +from .models import Extension, InstallableExtension + + +async def install_extension(ext_info: InstallableExtension) -> Extension: + extension = Extension.from_installable_ext(ext_info) + installed_ext = await get_installed_extension(ext_info.id) + ext_info.payments = installed_ext.payments if installed_ext else [] + + await ext_info.download_archive() + + ext_info.extract_archive() + + db_version = (await get_dbversions()).get(ext_info.id, 0) + await migrate_extension_database(extension, db_version) + + await add_installed_extension(ext_info) + + if extension.is_upgrade_extension: + # call stop while the old routes are still active + await stop_extension_background_work(ext_info.id) + + return extension + + +async def uninstall_extension(ext_id: str): + await stop_extension_background_work(ext_id) + + settings.deactivate_extension_paths(ext_id) + + extension = await get_installed_extension(ext_id) + if extension: + extension.clean_extension_files() + await delete_installed_extension(ext_id=ext_id) + + +async def activate_extension(ext: Extension): + core_app_extra.register_new_ext_routes(ext) + await update_installed_extension_state(ext_id=ext.code, active=True) + + +async def deactivate_extension(ext_id: str): + settings.deactivate_extension_paths(ext_id) + await update_installed_extension_state(ext_id=ext_id, active=False) + + +async def stop_extension_background_work(ext_id: str) -> bool: + """ + Stop background work for extension (like asyncio.Tasks, WebSockets, etc). + Extensions SHOULD expose a `api_stop()` function. + """ + upgrade_hash = settings.lnbits_upgraded_extensions.get(ext_id, "") + ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash) + + try: + logger.info(f"Stopping background work for extension '{ext.module_name}'.") + old_module = importlib.import_module(ext.module_name) + + # Extensions must expose an `{ext_id}_stop()` function at the module level + # The `api_stop()` function is for backwards compatibility (will be deprecated) + stop_fns = [f"{ext_id}_stop", "api_stop"] + stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None) + assert stop_fn_name, "No stop function found for '{ext.module_name}'" + + stop_fn = getattr(old_module, stop_fn_name) + if stop_fn: + if asyncio.iscoroutinefunction(stop_fn): + await stop_fn() + else: + stop_fn() + + logger.info(f"Stopped background work for extension '{ext.module_name}'.") + except Exception as ex: + logger.warning(f"Failed to stop background work for '{ext.module_name}'.") + logger.warning(ex) + return False + + return True diff --git a/lnbits/core/extensions/helpers.py b/lnbits/core/extensions/helpers.py new file mode 100644 index 000000000..645b1720f --- /dev/null +++ b/lnbits/core/extensions/helpers.py @@ -0,0 +1,56 @@ +import hashlib +from typing import Any, Optional +from urllib import request + +import httpx +from loguru import logger +from packaging import version + +from lnbits.settings import settings + + +def version_parse(v: str): + """ + Wrapper for version.parse() that does not throw if the version is invalid. + Instead it return the lowest possible version ("0.0.0") + """ + try: + return version.parse(v) + except Exception: + return version.parse("0.0.0") + + +async def github_api_get(url: str, error_msg: Optional[str]) -> Any: + headers = {"User-Agent": settings.user_agent} + if settings.lnbits_ext_github_token: + headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}" + async with httpx.AsyncClient(headers=headers) as client: + resp = await client.get(url) + if resp.status_code != 200: + logger.warning(f"{error_msg} ({url}): {resp.text}") + resp.raise_for_status() + return resp.json() + + +def download_url(url, save_path): + with request.urlopen(url, timeout=60) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + + +def icon_to_github_url(source_repo: str, path: Optional[str]) -> str: + if not path: + return "" + _, _, *rest = path.split("/") + tail = "/".join(rest) + return f"https://github.com/{source_repo}/raw/main/{tail}" diff --git a/lnbits/extension_manager.py b/lnbits/core/extensions/models.py similarity index 72% rename from lnbits/extension_manager.py rename to lnbits/core/extensions/models.py index ff11f4966..8273c8deb 100644 --- a/lnbits/extension_manager.py +++ b/lnbits/core/extensions/models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import hashlib import json @@ -6,16 +8,22 @@ import shutil import sys import zipfile from pathlib import Path -from typing import Any, List, NamedTuple, Optional, Tuple -from urllib import request +from typing import Any, NamedTuple, Optional import httpx from loguru import logger -from packaging import version from pydantic import BaseModel from lnbits.settings import settings +from .helpers import ( + download_url, + file_hash, + github_api_get, + icon_to_github_url, + version_parse, +) + class ExplicitRelease(BaseModel): id: str @@ -23,7 +31,7 @@ class ExplicitRelease(BaseModel): version: str archive: str hash: str - dependencies: List[str] = [] + dependencies: list[str] = [] repo: Optional[str] icon: Optional[str] short_description: Optional[str] @@ -48,9 +56,9 @@ class GitHubRelease(BaseModel): class Manifest(BaseModel): - featured: List[str] = [] - extensions: List["ExplicitRelease"] = [] - repos: List["GitHubRelease"] = [] + featured: list[str] = [] + extensions: list[ExplicitRelease] = [] + repos: list[GitHubRelease] = [] class GitHubRepoRelease(BaseModel): @@ -81,6 +89,17 @@ class ExtensionConfig(BaseModel): return True return version_parse(self.min_lnbits_version) <= version_parse(settings.version) + @classmethod + async def fetch_github_release_config( + cls, org: str, repo: str, tag_name: str + ) -> Optional[ExtensionConfig]: + config_url = ( + f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json" + ) + error_msg = "Cannot fetch GitHub extension config" + config = await github_api_get(config_url, error_msg) + return ExtensionConfig.parse_obj(config) + class ReleasePaymentInfo(BaseModel): amount: Optional[int] = None @@ -112,7 +131,7 @@ class UserExtension(BaseModel): return self.extra.paid_to_enable is True @classmethod - def from_row(cls, data: dict) -> "UserExtension": + def from_row(cls, data: dict) -> UserExtension: ext = UserExtension(**data) ext.extra = ( UserExtensionInfo(**json.loads(data["_extra"] or "{}")) @@ -122,124 +141,6 @@ class UserExtension(BaseModel): return ext -def download_url(url, save_path): - with request.urlopen(url, timeout=60) as dl_file: - with open(save_path, "wb") as out_file: - out_file.write(dl_file.read()) - - -def file_hash(filename): - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(filename, "rb", buffering=0) as f: - while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest() - - -async def fetch_github_repo_info( - org: str, repository: str -) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]: - repo_url = f"https://api.github.com/repos/{org}/{repository}" - error_msg = "Cannot fetch extension repo" - repo = await github_api_get(repo_url, error_msg) - github_repo = GitHubRepo.parse_obj(repo) - - lates_release_url = ( - f"https://api.github.com/repos/{org}/{repository}/releases/latest" - ) - error_msg = "Cannot fetch extension releases" - latest_release: Any = await github_api_get(lates_release_url, error_msg) - - config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json" - error_msg = "Cannot fetch config for extension" - config = await github_api_get(config_url, error_msg) - - return ( - github_repo, - GitHubRepoRelease.parse_obj(latest_release), - ExtensionConfig.parse_obj(config), - ) - - -async def fetch_manifest(url) -> Manifest: - error_msg = "Cannot fetch extensions manifest" - manifest = await github_api_get(url, error_msg) - return Manifest.parse_obj(manifest) - - -async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]: - releases_url = f"https://api.github.com/repos/{org}/{repo}/releases" - error_msg = "Cannot fetch extension releases" - releases = await github_api_get(releases_url, error_msg) - return [GitHubRepoRelease.parse_obj(r) for r in releases] - - -async def fetch_github_release_config( - org: str, repo: str, tag_name: str -) -> Optional[ExtensionConfig]: - config_url = ( - f"https://raw.githubusercontent.com/{org}/{repo}/{tag_name}/config.json" - ) - error_msg = "Cannot fetch GitHub extension config" - config = await github_api_get(config_url, error_msg) - return ExtensionConfig.parse_obj(config) - - -async def github_api_get(url: str, error_msg: Optional[str]) -> Any: - headers = {"User-Agent": settings.user_agent} - if settings.lnbits_ext_github_token: - headers["Authorization"] = f"Bearer {settings.lnbits_ext_github_token}" - async with httpx.AsyncClient(headers=headers) as client: - resp = await client.get(url) - if resp.status_code != 200: - logger.warning(f"{error_msg} ({url}): {resp.text}") - resp.raise_for_status() - return resp.json() - - -async def fetch_release_payment_info( - url: str, amount: Optional[int] = None -) -> Optional[ReleasePaymentInfo]: - if amount: - url = f"{url}?amount={amount}" - try: - async with httpx.AsyncClient() as client: - resp = await client.get(url) - resp.raise_for_status() - return ReleasePaymentInfo(**resp.json()) - except Exception as e: - logger.warning(e) - return None - - -async def fetch_release_details(details_link: str) -> Optional[dict]: - - try: - async with httpx.AsyncClient() as client: - resp = await client.get(details_link) - resp.raise_for_status() - data = resp.json() - if "description_md" in data: - resp = await client.get(data["description_md"]) - if not resp.is_error: - data["description_md"] = resp.text - - return data - except Exception as e: - logger.warning(e) - return None - - -def icon_to_github_url(source_repo: str, path: Optional[str]) -> str: - if not path: - return "" - _, _, *rest = path.split("/") - tail = "/".join(rest) - return f"https://github.com/{source_repo}/raw/main/{tail}" - - class Extension(NamedTuple): code: str is_valid: bool @@ -247,7 +148,7 @@ class Extension(NamedTuple): name: Optional[str] = None short_description: Optional[str] = None tile: Optional[str] = None - contributors: Optional[List[str]] = None + contributors: Optional[list[str]] = None hidden: bool = False migration_module: Optional[str] = None db_name: Optional[str] = None @@ -269,7 +170,7 @@ class Extension(NamedTuple): return self.upgrade_hash != "" @classmethod - def from_installable_ext(cls, ext_info: "InstallableExtension") -> "Extension": + def from_installable_ext(cls, ext_info: InstallableExtension) -> Extension: return Extension( code=ext_info.id, is_valid=True, @@ -278,22 +179,43 @@ class Extension(NamedTuple): upgrade_hash=ext_info.hash if ext_info.module_installed else "", ) + @classmethod + def get_valid_extensions( + cls, include_deactivated: Optional[bool] = True + ) -> list[Extension]: + valid_extensions = [ + extension for extension in cls._extensions() if extension.is_valid + ] -# All subdirectories in the current directory, not recursive. + if include_deactivated: + return valid_extensions + if settings.lnbits_extensions_deactivate_all: + return [] -class ExtensionManager: - def __init__(self) -> None: + return [ + e + for e in valid_extensions + if e.code not in settings.lnbits_deactivated_extensions + ] + + @classmethod + def get_valid_extension( + cls, ext_id: str, include_deactivated: Optional[bool] = True + ) -> Optional[Extension]: + all_extensions = cls.get_valid_extensions(include_deactivated) + return next((e for e in all_extensions if e.code == ext_id), None) + + @classmethod + def _extensions(cls) -> list[Extension]: p = Path(settings.lnbits_extensions_path, "extensions") Path(p).mkdir(parents=True, exist_ok=True) - self._extension_folders: List[Path] = [f for f in p.iterdir() if f.is_dir()] + extension_folders: list[Path] = [f for f in p.iterdir() if f.is_dir()] - @property - def extensions(self) -> List[Extension]: # todo: remove this property somehow, it is too expensive - output: List[Extension] = [] + output: list[Extension] = [] - for extension_folder in self._extension_folders: + for extension_folder in extension_folders: extension_code = extension_folder.parts[-1] try: with open(extension_folder / "config.json") as json_file: @@ -356,13 +278,27 @@ class ExtensionRelease(BaseModel): if not self.pay_link: return - payment_info = await fetch_release_payment_info(self.pay_link) + payment_info = await self.fetch_release_payment_info() self.cost_sats = payment_info.amount if payment_info else None + async def fetch_release_payment_info( + self, amount: Optional[int] = None + ) -> Optional[ReleasePaymentInfo]: + url = f"{self.pay_link}?amount={amount}" if amount else self.pay_link + assert url, "Missing URL for payment info." + try: + async with httpx.AsyncClient() as client: + resp = await client.get(url) + resp.raise_for_status() + return ReleasePaymentInfo(**resp.json()) + except Exception as e: + logger.warning(e) + return None + @classmethod def from_github_release( - cls, source_repo: str, r: "GitHubRepoRelease" - ) -> "ExtensionRelease": + cls, source_repo: str, r: GitHubRepoRelease + ) -> ExtensionRelease: return ExtensionRelease( name=r.name, description=r.name, @@ -377,8 +313,8 @@ class ExtensionRelease(BaseModel): @classmethod def from_explicit_release( - cls, source_repo: str, e: "ExplicitRelease" - ) -> "ExtensionRelease": + cls, source_repo: str, e: ExplicitRelease + ) -> ExtensionRelease: return ExtensionRelease( name=e.name, version=e.version, @@ -397,9 +333,9 @@ class ExtensionRelease(BaseModel): ) @classmethod - async def get_github_releases(cls, org: str, repo: str) -> List["ExtensionRelease"]: + async def get_github_releases(cls, org: str, repo: str) -> list[ExtensionRelease]: try: - github_releases = await fetch_github_releases(org, repo) + github_releases = await cls.fetch_github_releases(org, repo) return [ ExtensionRelease.from_github_release(f"{org}/{repo}", r) for r in github_releases @@ -408,6 +344,33 @@ class ExtensionRelease(BaseModel): logger.warning(e) return [] + @classmethod + async def fetch_github_releases( + cls, org: str, repo: str + ) -> list[GitHubRepoRelease]: + releases_url = f"https://api.github.com/repos/{org}/{repo}/releases" + error_msg = "Cannot fetch extension releases" + releases = await github_api_get(releases_url, error_msg) + return [GitHubRepoRelease.parse_obj(r) for r in releases] + + @classmethod + async def fetch_release_details(cls, details_link: str) -> Optional[dict]: + + try: + async with httpx.AsyncClient() as client: + resp = await client.get(details_link) + resp.raise_for_status() + data = resp.json() + if "description_md" in data: + resp = await client.get(data["description_md"]) + if not resp.is_error: + data["description_md"] = resp.text + + return data + except Exception as e: + logger.warning(e) + return None + class InstallableExtension(BaseModel): id: str @@ -415,13 +378,13 @@ class InstallableExtension(BaseModel): active: Optional[bool] = False short_description: Optional[str] = None icon: Optional[str] = None - dependencies: List[str] = [] + dependencies: list[str] = [] is_admin_only: bool = False stars: int = 0 featured = False latest_release: Optional[ExtensionRelease] = None installed_release: Optional[ExtensionRelease] = None - payments: List[ReleasePaymentInfo] = [] + payments: list[ReleasePaymentInfo] = [] pay_to_enable: Optional[PayToEnableInfo] = None archive: Optional[str] = None @@ -546,16 +509,6 @@ class InstallableExtension(BaseModel): shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir)) logger.success(f"Extension {self.name} ({self.installed_version}) installed.") - def notify_upgrade(self, upgrade_hash: Optional[str]) -> None: - """ - Update the list of upgraded extensions. The middleware will perform - redirects based on this - """ - if upgrade_hash: - settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}") - - settings.lnbits_all_extensions_ids.add(self.id) - def clean_extension_files(self): # remove downloaded archive if self.zip_path.is_file(): @@ -610,7 +563,7 @@ class InstallableExtension(BaseModel): self.payments.append(payment_info) @classmethod - def from_row(cls, data: dict) -> "InstallableExtension": + def from_row(cls, data: dict) -> InstallableExtension: meta = json.loads(data["meta"]) ext = InstallableExtension(**data) if "installed_release" in meta: @@ -623,9 +576,7 @@ class InstallableExtension(BaseModel): return ext @classmethod - def from_rows( - cls, rows: Optional[List[Any]] = None - ) -> List["InstallableExtension"]: + def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]: if rows is None: rows = [] return [InstallableExtension.from_row(row) for row in rows] @@ -633,9 +584,9 @@ class InstallableExtension(BaseModel): @classmethod async def from_github_release( cls, github_release: GitHubRelease - ) -> Optional["InstallableExtension"]: + ) -> Optional[InstallableExtension]: try: - repo, latest_release, config = await fetch_github_repo_info( + repo, latest_release, config = await cls.fetch_github_repo_info( github_release.organisation, github_release.repository ) source_repo = f"{github_release.organisation}/{github_release.repository}" @@ -657,7 +608,7 @@ class InstallableExtension(BaseModel): return None @classmethod - def from_explicit_release(cls, e: ExplicitRelease) -> "InstallableExtension": + def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension: return InstallableExtension( id=e.id, name=e.name, @@ -670,13 +621,13 @@ class InstallableExtension(BaseModel): @classmethod async def get_installable_extensions( cls, - ) -> List["InstallableExtension"]: - extension_list: List[InstallableExtension] = [] - extension_id_list: List[str] = [] + ) -> list[InstallableExtension]: + extension_list: list[InstallableExtension] = [] + extension_id_list: list[str] = [] for url in settings.lnbits_extensions_manifests: try: - manifest = await fetch_manifest(url) + manifest = await cls.fetch_manifest(url) for r in manifest.repos: ext = await InstallableExtension.from_github_release(r) @@ -712,12 +663,12 @@ class InstallableExtension(BaseModel): return extension_list @classmethod - async def get_extension_releases(cls, ext_id: str) -> List["ExtensionRelease"]: - extension_releases: List[ExtensionRelease] = [] + async def get_extension_releases(cls, ext_id: str) -> list[ExtensionRelease]: + extension_releases: list[ExtensionRelease] = [] for url in settings.lnbits_extensions_manifests: try: - manifest = await fetch_manifest(url) + manifest = await cls.fetch_manifest(url) for r in manifest.repos: if r.id != ext_id: continue @@ -741,8 +692,8 @@ class InstallableExtension(BaseModel): @classmethod async def get_extension_release( cls, ext_id: str, source_repo: str, archive: str, version: str - ) -> Optional["ExtensionRelease"]: - all_releases: List[ExtensionRelease] = ( + ) -> Optional[ExtensionRelease]: + all_releases: list[ExtensionRelease] = ( await InstallableExtension.get_extension_releases(ext_id) ) selected_release = [ @@ -755,6 +706,37 @@ class InstallableExtension(BaseModel): return selected_release[0] if len(selected_release) != 0 else None + @classmethod + async def fetch_github_repo_info( + cls, org: str, repository: str + ) -> tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]: + repo_url = f"https://api.github.com/repos/{org}/{repository}" + error_msg = "Cannot fetch extension repo" + repo = await github_api_get(repo_url, error_msg) + github_repo = GitHubRepo.parse_obj(repo) + + lates_release_url = ( + f"https://api.github.com/repos/{org}/{repository}/releases/latest" + ) + error_msg = "Cannot fetch extension releases" + latest_release: Any = await github_api_get(lates_release_url, error_msg) + + config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json" + error_msg = "Cannot fetch config for extension" + config = await github_api_get(config_url, error_msg) + + return ( + github_repo, + GitHubRepoRelease.parse_obj(latest_release), + ExtensionConfig.parse_obj(config), + ) + + @classmethod + async def fetch_manifest(cls, url) -> Manifest: + error_msg = "Cannot fetch extensions manifest" + manifest = await github_api_get(url, error_msg) + return Manifest.parse_obj(manifest) + class CreateExtension(BaseModel): ext_id: str @@ -769,32 +751,3 @@ class ExtensionDetailsRequest(BaseModel): ext_id: str source_repo: str version: str - - -def get_valid_extensions(include_deactivated: Optional[bool] = True) -> List[Extension]: - valid_extensions = [ - extension for extension in ExtensionManager().extensions if extension.is_valid - ] - - if include_deactivated: - return valid_extensions - - if settings.lnbits_extensions_deactivate_all: - return [] - - return [ - e - for e in valid_extensions - if e.code not in settings.lnbits_deactivated_extensions - ] - - -def version_parse(v: str): - """ - Wrapper for version.parse() that does not throw if the version is invalid. - Instead it return the lowest possible version ("0.0.0") - """ - try: - return version.parse(v) - except Exception: - return version.parse("0.0.0") diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py index b35169381..d79763a0e 100644 --- a/lnbits/core/helpers.py +++ b/lnbits/core/helpers.py @@ -1,9 +1,8 @@ import importlib import re -from typing import Any, Optional +from typing import Any from uuid import UUID -import httpx from loguru import logger from lnbits.core import migrations as core_migrations @@ -13,11 +12,10 @@ from lnbits.core.crud import ( update_migration_version, ) from lnbits.core.db import db as core_db -from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection -from lnbits.extension_manager import ( +from lnbits.core.extensions.models import ( Extension, - get_valid_extensions, ) +from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection from lnbits.settings import settings @@ -55,68 +53,6 @@ async def run_migration( await update_migration_version(conn, db_name, version) -async def stop_extension_background_work( - ext_id: str, user: str, access_token: Optional[str] = None -): - """ - Stop background work for extension (like asyncio.Tasks, WebSockets, etc). - Extensions SHOULD expose a `api_stop()` function and/or a DELETE enpoint - at the root level of their API. - """ - stopped = await _stop_extension_background_work(ext_id) - - if not stopped: - # fallback to REST API call - await _stop_extension_background_work_via_api(ext_id, user, access_token) - - -async def _stop_extension_background_work(ext_id) -> bool: - upgrade_hash = settings.extension_upgrade_hash(ext_id) or "" - ext = Extension(ext_id, True, False, upgrade_hash=upgrade_hash) - - try: - logger.info(f"Stopping background work for extension '{ext.module_name}'.") - old_module = importlib.import_module(ext.module_name) - - # Extensions must expose an `{ext_id}_stop()` function at the module level - # The `api_stop()` function is for backwards compatibility (will be deprecated) - stop_fns = [f"{ext_id}_stop", "api_stop"] - stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None) - assert stop_fn_name, "No stop function found for '{ext.module_name}'" - - stop_fn = getattr(old_module, stop_fn_name) - if stop_fn: - await stop_fn() - - logger.info(f"Stopped background work for extension '{ext.module_name}'.") - except Exception as ex: - logger.warning(f"Failed to stop background work for '{ext.module_name}'.") - logger.warning(ex) - return False - - return True - - -async def _stop_extension_background_work_via_api(ext_id, user, access_token): - logger.info( - f"Stopping background work for extension '{ext_id}' using the REST API." - ) - async with httpx.AsyncClient() as client: - try: - url = f"http://{settings.host}:{settings.port}/{ext_id}/api/v1?usr={user}" - headers = ( - {"Authorization": "Bearer " + access_token} if access_token else None - ) - resp = await client.delete(url=url, headers=headers) - resp.raise_for_status() - logger.info(f"Stopped background work for extension '{ext_id}'.") - except Exception as ex: - logger.warning( - f"Failed to stop background work for '{ext_id}' using the REST API." - ) - logger.warning(ex) - - def to_valid_user_id(user_id: str) -> UUID: if len(user_id) < 32: raise ValueError("User ID must have at least 128 bits") @@ -161,7 +97,7 @@ async def migrate_databases(): await load_disabled_extension_list() # todo: revisit, use installed extensions - for ext in get_valid_extensions(False): + for ext in Extension.get_valid_extensions(False): current_version = current_versions.get(ext.code, 0) try: await migrate_extension_database(ext, current_version) diff --git a/lnbits/core/views/extension_api.py b/lnbits/core/views/extension_api.py index 2dacf65b2..71bd3215f 100644 --- a/lnbits/core/views/extension_api.py +++ b/lnbits/core/views/extension_api.py @@ -1,8 +1,6 @@ -import sys from http import HTTPStatus from typing import ( List, - Optional, ) from bolt11 import decode as bolt11_decode @@ -13,10 +11,21 @@ from fastapi import ( ) from loguru import logger -from lnbits.core.db import core_app_extra -from lnbits.core.helpers import ( - migrate_extension_database, - stop_extension_background_work, +from lnbits.core.extensions.extension_manager import ( + activate_extension, + deactivate_extension, + install_extension, + uninstall_extension, +) +from lnbits.core.extensions.models import ( + CreateExtension, + Extension, + ExtensionConfig, + ExtensionRelease, + InstallableExtension, + PayToEnableInfo, + ReleasePaymentInfo, + UserExtensionInfo, ) from lnbits.core.models import ( SimpleStatus, @@ -24,36 +33,18 @@ from lnbits.core.models import ( ) from lnbits.core.services import check_transaction_status, create_invoice from lnbits.decorators import ( - check_access_token, check_admin, check_user_exists, ) -from lnbits.extension_manager import ( - CreateExtension, - Extension, - ExtensionRelease, - InstallableExtension, - PayToEnableInfo, - ReleasePaymentInfo, - UserExtensionInfo, - fetch_github_release_config, - fetch_release_details, - fetch_release_payment_info, - get_valid_extensions, -) -from lnbits.settings import settings from ..crud import ( - add_installed_extension, delete_dbversion, - delete_installed_extension, drop_extension_db, get_dbversions, get_installed_extension, get_installed_extensions, get_user_extension, update_extension_pay_to_enable, - update_installed_extension_state, update_user_extension, update_user_extension_extra, ) @@ -64,12 +55,8 @@ extension_router = APIRouter( ) -@extension_router.post("") -async def api_install_extension( - data: CreateExtension, - user: User = Depends(check_admin), - access_token: Optional[str] = Depends(check_access_token), -): +@extension_router.post("", dependencies=[Depends(check_admin)]) +async def api_install_extension(data: CreateExtension): release = await InstallableExtension.get_extension_release( data.ext_id, data.source_repo, data.archive, data.version ) @@ -89,43 +76,36 @@ async def api_install_extension( ) try: - installed_ext = await get_installed_extension(data.ext_id) - ext_info.payments = installed_ext.payments if installed_ext else [] + extension = await install_extension(ext_info) - await ext_info.download_archive() - - ext_info.extract_archive() - - extension = Extension.from_installable_ext(ext_info) - - db_version = (await get_dbversions()).get(data.ext_id, 0) - await migrate_extension_database(extension, db_version) - - ext_info.active = True - await add_installed_extension(ext_info) - - if extension.is_upgrade_extension: - # call stop while the old routes are still active - await stop_extension_background_work(data.ext_id, user.id, access_token) - - # mount routes for the new version - core_app_extra.register_new_ext_routes(extension) - - ext_info.notify_upgrade(extension.upgrade_hash) - settings.lnbits_deactivated_extensions.discard(data.ext_id) - - return extension - except AssertionError as exc: - raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc except Exception as exc: logger.warning(exc) ext_info.clean_extension_files() + detail = ( + str(exc) + if isinstance(exc, AssertionError) + else f"Failed to install extension '{ext_info.id}'." + f"({ext_info.installed_version})." + ) raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=( - f"Failed to install extension {ext_info.id} " - f"({ext_info.installed_version})." - ), + detail=detail, + ) from exc + + try: + await activate_extension(extension) + return extension + except Exception as exc: + logger.warning(exc) + await deactivate_extension(extension.code) + detail = ( + str(exc) + if isinstance(exc, AssertionError) + else f"Extension `{extension.code}` installed, but activation failed." + ) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=detail, ) from exc @@ -143,7 +123,7 @@ async def api_extension_details( ) assert release, "Details not found for release" - release_details = await fetch_release_details(details_link) + release_details = await ExtensionRelease.fetch_release_details(details_link) assert release_details, "Cannot fetch details for release" release_details["icon"] = release.icon release_details["repo"] = release.repo @@ -186,7 +166,7 @@ async def api_update_pay_to_enable( async def api_enable_extension( ext_id: str, user: User = Depends(check_user_exists) ) -> SimpleStatus: - if ext_id not in [e.code for e in get_valid_extensions()]: + if ext_id not in [e.code for e in Extension.get_valid_extensions()]: raise HTTPException( HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' doesn't exist." ) @@ -249,7 +229,7 @@ async def api_enable_extension( async def api_disable_extension( ext_id: str, user: User = Depends(check_user_exists) ) -> SimpleStatus: - if ext_id not in [e.code for e in get_valid_extensions()]: + if ext_id not in [e.code for e in Extension.get_valid_extensions()]: raise HTTPException( HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist." ) @@ -270,20 +250,14 @@ async def api_activate_extension(ext_id: str) -> SimpleStatus: try: logger.info(f"Activating extension: '{ext_id}'.") - all_extensions = get_valid_extensions() - ext = next((e for e in all_extensions if e.code == ext_id), None) + ext = Extension.get_valid_extension(ext_id) assert ext, f"Extension '{ext_id}' doesn't exist." - # if extension never loaded (was deactivated on server startup) - if ext_id not in sys.modules.keys(): - # run extension start-up routine - core_app_extra.register_new_ext_routes(ext) - settings.lnbits_deactivated_extensions.discard(ext_id) - - await update_installed_extension_state(ext_id=ext_id, active=True) + await activate_extension(ext) return SimpleStatus(success=True, message=f"Extension '{ext_id}' activated.") except Exception as exc: logger.warning(exc) + await deactivate_extension(ext_id) raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=(f"Failed to activate '{ext_id}'."), @@ -295,13 +269,10 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus: try: logger.info(f"Deactivating extension: '{ext_id}'.") - all_extensions = get_valid_extensions() - ext = next((e for e in all_extensions if e.code == ext_id), None) + ext = Extension.get_valid_extension(ext_id) assert ext, f"Extension '{ext_id}' doesn't exist." - settings.lnbits_deactivated_extensions.add(ext_id) - - await update_installed_extension_state(ext_id=ext_id, active=False) + await deactivate_extension(ext_id) return SimpleStatus(success=True, message=f"Extension '{ext_id}' deactivated.") except Exception as exc: logger.warning(exc) @@ -311,23 +282,19 @@ async def api_deactivate_extension(ext_id: str) -> SimpleStatus: ) from exc -@extension_router.delete("/{ext_id}") -async def api_uninstall_extension( - ext_id: str, - user: User = Depends(check_admin), - access_token: Optional[str] = Depends(check_access_token), -) -> SimpleStatus: - installed_extensions = await get_installed_extensions() +@extension_router.delete("/{ext_id}", dependencies=[Depends(check_admin)]) +async def api_uninstall_extension(ext_id: str) -> SimpleStatus: - extensions = [e for e in installed_extensions if e.id == ext_id] - if len(extensions) == 0: + extension = await get_installed_extension(ext_id) + if not extension: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail=f"Unknown extension id: {ext_id}", ) + installed_extensions = await get_installed_extensions() # check that other extensions do not depend on this one - for valid_ext_id in [ext.code for ext in get_valid_extensions()]: + for valid_ext_id in [ext.code for ext in Extension.get_valid_extensions()]: installed_ext = next( (ext for ext in installed_extensions if ext.id == valid_ext_id), None ) @@ -341,14 +308,7 @@ async def api_uninstall_extension( ) try: - # call stop while the old routes are still active - await stop_extension_background_work(ext_id, user.id, access_token) - - settings.lnbits_deactivated_extensions.add(ext_id) - - for ext_info in extensions: - ext_info.clean_extension_files() - await delete_installed_extension(ext_id=ext_info.id) + await uninstall_extension(ext_id) logger.success(f"Extension '{ext_id}' uninstalled.") return SimpleStatus(success=True, message=f"Extension '{ext_id}' uninstalled.") @@ -397,9 +357,8 @@ async def get_pay_to_install_invoice( assert release, "Release not found." assert release.pay_link, "Pay link not found for release." - payment_info = await fetch_release_payment_info( - release.pay_link, data.cost_sats - ) + payment_info = await release.fetch_release_payment_info(data.cost_sats) + assert payment_info and payment_info.payment_request, "Cannot request invoice." invoice = bolt11_decode(payment_info.payment_request) @@ -474,7 +433,7 @@ async def get_pay_to_enable_invoice( ) async def get_extension_release(org: str, repo: str, tag_name: str): try: - config = await fetch_github_release_config(org, repo, tag_name) + config = await ExtensionConfig.fetch_github_release_config(org, repo, tag_name) if not config: return {} diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 1793dee9f..c06adba21 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -12,6 +12,7 @@ from lnurl import decode as lnurl_decode from loguru import logger from pydantic.types import UUID4 +from lnbits.core.extensions.models import Extension, InstallableExtension from lnbits.core.helpers import to_valid_user_id from lnbits.core.models import User from lnbits.core.services import create_invoice @@ -20,7 +21,6 @@ from lnbits.helpers import template_renderer from lnbits.settings import settings from lnbits.wallets import get_funding_source -from ...extension_manager import InstallableExtension, get_valid_extensions from ...utils.exchange_rates import allowed_currencies, currencies from ..crud import ( create_account, @@ -104,7 +104,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)): installed_exts_ids = [] try: - all_ext_ids = [ext.code for ext in get_valid_extensions()] + all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()] inactive_extensions = [ e.id for e in await get_installed_extensions(active=False) ] diff --git a/lnbits/helpers.py b/lnbits/helpers.py index d7545d4b0..6b4240598 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -10,6 +10,7 @@ import shortuuid from pydantic import BaseModel from pydantic.schema import field_schema +from lnbits.core.extensions.models import Extension from lnbits.db import get_placeholder from lnbits.jinja2_templating import Jinja2Templates from lnbits.nodes import get_node_class @@ -18,7 +19,6 @@ from lnbits.settings import settings from lnbits.utils.crypto import AESCipher from .db import FilterModel -from .extension_manager import get_valid_extensions def get_db_vendor_name(): @@ -93,7 +93,7 @@ def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templa settings.lnbits_node_ui and get_node_class() is not None ) t.env.globals["LNBITS_NODE_UI_AVAILABLE"] = get_node_class() is not None - t.env.globals["EXTENSIONS"] = get_valid_extensions(False) + t.env.globals["EXTENSIONS"] = Extension.get_valid_extensions(False) if settings.lnbits_custom_logo: t.env.globals["USE_CUSTOM_LOGO"] = settings.lnbits_custom_logo diff --git a/lnbits/middleware.py b/lnbits/middleware.py index 92e6d073e..611995a06 100644 --- a/lnbits/middleware.py +++ b/lnbits/middleware.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, List, Tuple, Union +from typing import Any, List, Union from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse @@ -45,16 +45,11 @@ class InstalledExtensionMiddleware: await self.app(scope, receive, send) return - upgrade_path = next( - ( - e - for e in settings.lnbits_upgraded_extensions - if e.endswith(f"/{top_path}") - ), - None, - ) # re-route all trafic if the extension has been upgraded - if upgrade_path: + if top_path in settings.lnbits_upgraded_extensions: + upgrade_path = ( + f"""{settings.lnbits_upgraded_extensions[top_path]}/{top_path}""" + ) tail = "/".join(rest) scope["path"] = f"/upgrades/{upgrade_path}/{tail}" @@ -118,72 +113,12 @@ class ExtensionsRedirectMiddleware: return req_headers = scope["headers"] if "headers" in scope else [] - redirect = self._find_redirect(scope["path"], req_headers) + redirect = settings.find_extension_redirect(scope["path"], req_headers) if redirect: - scope["path"] = self._new_path(redirect, scope["path"]) + scope["path"] = redirect.new_path_from(scope["path"]) await self.app(scope, receive, send) - def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]): - return next( - ( - r - for r in settings.lnbits_extensions_redirects - if self._redirect_matches(r, path, req_headers) - ), - None, - ) - - def _redirect_matches( - self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]] - ) -> bool: - if "from_path" not in redirect: - return False - header_filters = ( - redirect["header_filters"] if "header_filters" in redirect else {} - ) - return self._has_common_path(redirect["from_path"], path) and self._has_headers( - header_filters, req_headers - ) - - def _has_headers( - self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]] - ) -> bool: - for h in filter_headers: - if not self._has_header(req_headers, (str(h), str(filter_headers[h]))): - return False - return True - - def _has_header( - self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str] - ) -> bool: - for h in req_headers: - if ( - h[0].decode().lower() == header[0].lower() - and h[1].decode() == header[1] - ): - return True - return False - - def _has_common_path(self, redirect_path: str, req_path: str) -> bool: - redirect_path_elements = redirect_path.split("/") - req_path_elements = req_path.split("/") - if len(redirect_path) > len(req_path): - return False - sub_path = req_path_elements[: len(redirect_path_elements)] - return redirect_path == "/".join(sub_path) - - def _new_path(self, redirect: dict, req_path: str) -> str: - from_path = redirect["from_path"].split("/") - redirect_to = redirect["redirect_to_path"].split("/") - req_tail_path = req_path.split("/")[len(from_path) :] - - elements = [ - e for e in ([redirect["ext_id"], *redirect_to, *req_tail_path]) if e != "" - ] - - return "/" + "/".join(elements) - def add_ratelimit_middleware(app: FastAPI): core_app_extra.register_new_ratelimiter() diff --git a/lnbits/settings.py b/lnbits/settings.py index 6a5c2ce1c..91c2a1e05 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -62,26 +62,132 @@ class ExtensionsInstallSettings(LNbitsSettings): lnbits_ext_github_token: str = Field(default="") +class RedirectPath(BaseModel): + ext_id: str + from_path: str + redirect_to_path: str + header_filters: dict = {} + + def in_conflict(self, other: RedirectPath) -> bool: + if self.ext_id == other.ext_id: + return False + return self.redirect_matches( + other.from_path, list(other.header_filters.items()) + ) or other.redirect_matches(self.from_path, list(self.header_filters.items())) + + def find_in_conflict(self, others: list[RedirectPath]) -> Optional[RedirectPath]: + for other in others: + if self.in_conflict(other): + return other + return None + + def new_path_from(self, req_path: str) -> str: + from_path = self.from_path.split("/") + redirect_to = self.redirect_to_path.split("/") + req_tail_path = req_path.split("/")[len(from_path) :] + + elements = [e for e in ([self.ext_id, *redirect_to, *req_tail_path]) if e != ""] + + return "/" + "/".join(elements) + + def redirect_matches(self, path: str, req_headers: list[tuple[str, str]]) -> bool: + return self._has_common_path(path) and self._has_headers(req_headers) + + def _has_common_path(self, req_path: str) -> bool: + if len(self.from_path) > len(req_path): + return False + + redirect_path_elements = self.from_path.split("/") + req_path_elements = req_path.split("/") + + sub_path = req_path_elements[: len(redirect_path_elements)] + return self.from_path == "/".join(sub_path) + + def _has_headers(self, req_headers: list[tuple[str, str]]) -> bool: + for h in self.header_filters: + if not self._has_header(req_headers, (str(h), str(self.header_filters[h]))): + return False + return True + + def _has_header( + self, req_headers: list[tuple[str, str]], header: tuple[str, str] + ) -> bool: + for h in req_headers: + if h[0].lower() == header[0].lower() and h[1].lower() == header[1].lower(): + return True + return False + + class InstalledExtensionsSettings(LNbitsSettings): # installed extensions that have been deactivated lnbits_deactivated_extensions: set[str] = Field(default=[]) # upgraded extensions that require API redirects - lnbits_upgraded_extensions: set[str] = Field(default=[]) + lnbits_upgraded_extensions: dict[str, str] = Field(default={}) # list of redirects that extensions want to perform - lnbits_extensions_redirects: list[Any] = Field(default=[]) + lnbits_extensions_redirects: list[RedirectPath] = Field(default=[]) # list of all extension ids lnbits_all_extensions_ids: set[Any] = Field(default=[]) - def extension_upgrade_path(self, ext_id: str) -> Optional[str]: + def find_extension_redirect( + self, path: str, req_headers: list[tuple[bytes, bytes]] + ) -> Optional[RedirectPath]: + headers = [(k.decode(), v.decode()) for k, v in req_headers] return next( - (e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")), + ( + r + for r in self.lnbits_extensions_redirects + if r.redirect_matches(path, headers) + ), None, ) - def extension_upgrade_hash(self, ext_id: str) -> Optional[str]: - path = settings.extension_upgrade_path(ext_id) - return path.split("/")[0] if path else None + def activate_extension_paths( + self, + ext_id: str, + upgrade_hash: Optional[str] = None, + ext_redirects: Optional[list[dict]] = None, + ): + self.lnbits_deactivated_extensions.discard(ext_id) + + """ + Update the list of upgraded extensions. The middleware will perform + redirects based on this + """ + if upgrade_hash: + self.lnbits_upgraded_extensions[ext_id] = upgrade_hash + + if ext_redirects: + self._activate_extension_redirects(ext_id, ext_redirects) + + self.lnbits_all_extensions_ids.add(ext_id) + + def deactivate_extension_paths(self, ext_id: str): + self.lnbits_deactivated_extensions.add(ext_id) + self._remove_extension_redirects(ext_id) + + def _activate_extension_redirects(self, ext_id: str, ext_redirects: list[dict]): + ext_redirect_paths = [ + RedirectPath(**{"ext_id": ext_id, **er}) for er in ext_redirects + ] + existing_redirects = { + r.ext_id + for r in self.lnbits_extensions_redirects + if r.find_in_conflict(ext_redirect_paths) + } + + assert len(existing_redirects) == 0, ( + f"Cannot redirect for extension '{ext_id}'." + f" Already mapped by {existing_redirects}." + ) + + self._remove_extension_redirects(ext_id) + self.lnbits_extensions_redirects += ext_redirect_paths + + def _remove_extension_redirects(self, ext_id: str): + self.lnbits_extensions_redirects = [ + er for er in self.lnbits_extensions_redirects if er.ext_id != ext_id + ] class ThemesSettings(LNbitsSettings): diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py new file mode 100644 index 000000000..a300afb0f --- /dev/null +++ b/tests/unit/test_settings.py @@ -0,0 +1,168 @@ +import pytest + +from lnbits.settings import RedirectPath + +lnurlp_redirect_path = { + "from_path": "/.well-known/lnurlp", + "redirect_to_path": "/api/v1/well-known", +} +lnurlp_redirect_path_with_headers = { + "from_path": "/.well-known/lnurlp", + "redirect_to_path": "/api/v1/well-known", + "header_filters": {"accept": "application/nostr+json"}, +} + +lnaddress_redirect_path = { + "from_path": "/.well-known/lnurlp", + "redirect_to_path": "/api/v1/well-known", +} + +nostrrelay_redirect_path = { + "from_path": "/", + "redirect_to_path": "/api/v1/relay-info", + "header_filters": {"accept": "application/nostr+json"}, +} + + +@pytest.fixture() +def lnurlp(): + return RedirectPath(ext_id="lnurlp", **lnurlp_redirect_path) + + +@pytest.fixture() +def lnurlp_with_headers(): + return RedirectPath( + ext_id="lnurlp_with_headers", **lnurlp_redirect_path_with_headers + ) + + +@pytest.fixture() +def lnaddress(): + return RedirectPath(ext_id="lnaddress", **lnaddress_redirect_path) + + +@pytest.fixture() +def nostrrelay(): + return RedirectPath(ext_id="nostrrelay", **nostrrelay_redirect_path) + + +def test_redirect_path_self_not_in_conflict( + lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath +): + assert not lnurlp.in_conflict(lnurlp), "Path is not in conflict with itself." + assert not lnaddress.in_conflict(lnaddress), "Path is not in conflict with itself." + assert not nostrrelay.in_conflict( + nostrrelay + ), "Path is not in conflict with itself." + + assert not lnurlp.in_conflict(nostrrelay) + + assert not nostrrelay.in_conflict(lnurlp) + + +def test_redirect_path_not_in_conflict( + lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath +): + + assert not lnurlp.in_conflict(nostrrelay) + + assert not nostrrelay.in_conflict(lnurlp) + + assert not lnaddress.in_conflict(nostrrelay) + + assert not nostrrelay.in_conflict(lnaddress) + + +def test_redirect_path_in_conflict(lnurlp: RedirectPath, lnaddress: RedirectPath): + assert lnurlp.in_conflict(lnaddress) + assert lnaddress.in_conflict(lnurlp) + + +def test_redirect_path_find_conflict( + lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath +): + assert lnurlp.find_in_conflict([nostrrelay, lnaddress]) + assert lnurlp.find_in_conflict([lnaddress, nostrrelay]) + assert lnaddress.find_in_conflict([nostrrelay, lnurlp]) + assert lnaddress.find_in_conflict([lnurlp, nostrrelay]) + + +def test_redirect_path_find_no_conflict( + lnurlp: RedirectPath, lnaddress: RedirectPath, nostrrelay: RedirectPath +): + assert not nostrrelay.find_in_conflict([lnurlp, lnaddress]) + assert not lnurlp.find_in_conflict([nostrrelay]) + assert not lnaddress.find_in_conflict([nostrrelay]) + + +def test_redirect_path_in_conflict_with_headers( + lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath +): + assert lnurlp.in_conflict(lnurlp_with_headers) + assert lnurlp_with_headers.in_conflict(lnurlp) + + +def test_redirect_path_matches_with_headers( + lnurlp: RedirectPath, lnurlp_with_headers: RedirectPath +): + headers_list = list(lnurlp_with_headers.header_filters.items()) + assert lnurlp.redirect_matches( + path=lnurlp_with_headers.from_path, + req_headers=headers_list, + ) + assert lnurlp_with_headers.redirect_matches( + path=lnurlp_redirect_path["from_path"], + req_headers=[("ACCEPT", "APPlication/nostr+json")], + ) + assert lnurlp_with_headers.redirect_matches( + path=lnurlp_redirect_path["from_path"], + req_headers=[("accept", "application/nostr+json"), ("my_header", "my_value")], + ) + + assert not lnurlp_with_headers.redirect_matches( + path=lnurlp_redirect_path["from_path"], req_headers=[] + ) + assert not lnurlp_with_headers.redirect_matches( + path=lnurlp_redirect_path["from_path"], + req_headers=[("accept", "application/json")], + ) + assert not lnurlp_with_headers.redirect_matches(path="/random/path", req_headers=[]) + assert not lnurlp_with_headers.redirect_matches(path="/random_path", req_headers=[]) + assert not lnurlp_with_headers.redirect_matches( + path="/.well-known/lnurlp", req_headers=[] + ) + assert lnurlp.redirect_matches(path="/.well-known/lnurlp", req_headers=[]) + assert lnurlp.redirect_matches( + path="/.well-known/lnurlp/some/other/path", req_headers=[] + ) + assert lnurlp.redirect_matches( + path="/.well-known/lnurlp/some/other/path", + req_headers=headers_list, + ) + assert not lnurlp_with_headers.redirect_matches( + path="/.well-known/lnurlp", req_headers=[] + ) + assert not lnurlp_with_headers.redirect_matches( + path="/.well-known/lnurlp/some/other/path", req_headers=[] + ) + assert lnurlp_with_headers.redirect_matches( + path="/.well-known/lnurlp/some/other/path", + req_headers=headers_list, + ) + + +def test_redirect_path_new_path_from(lnurlp: RedirectPath): + assert lnurlp.new_path_from("") == "/lnurlp/api/v1/well-known" + assert lnurlp.new_path_from("/") == "/lnurlp/api/v1/well-known" + assert lnurlp.new_path_from("/path") == "/lnurlp/api/v1/well-known" + assert lnurlp.new_path_from("/path/more") == "/lnurlp/api/v1/well-known" + + assert lnurlp.new_path_from("/.well-known/lnurlp") == "/lnurlp/api/v1/well-known" + assert ( + lnurlp.new_path_from("/.well-known/lnurlp/path") + == "/lnurlp/api/v1/well-known/path" + ) + assert ( + lnurlp.new_path_from("/.well-known/lnurlp/path/more") + == "/lnurlp/api/v1/well-known/path/more" + )