diff --git a/lnbits/app.py b/lnbits/app.py index 92c4105f4..3c848c9d7 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -26,12 +26,10 @@ from .commands import migrate_databases from .core import core_app, core_app_extra from .core.services import check_admin_settings from .core.views.generic import core_html_routes +from .extensions import Extension, InstalledExtensionMiddleware, get_valid_extensions from .helpers import ( - Extension, - InstalledExtensionMiddleware, get_css_vendored, get_js_vendored, - get_valid_extensions, template_renderer, url_for_vendored, ) diff --git a/lnbits/commands.py b/lnbits/commands.py index 0de584c47..4f8ef2609 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -14,12 +14,8 @@ from .core import migrations as core_migrations from .core.crud import USER_ID_ALL, get_dbversions, get_inactive_extensions from .core.helpers import migrate_extension_database, run_migration from .db import COCKROACH, POSTGRES, SQLITE -from .helpers import ( - get_css_vendored, - get_js_vendored, - get_valid_extensions, - url_for_vendored, -) +from .extensions import get_valid_extensions +from .helpers import get_css_vendored, get_js_vendored, url_for_vendored @click.command("migrate") diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py index 4cd5edbcc..78d188d28 100644 --- a/lnbits/core/helpers.py +++ b/lnbits/core/helpers.py @@ -3,7 +3,7 @@ import re from loguru import logger -from lnbits.helpers import Extension +from lnbits.extensions import Extension from . import db as core_db from .crud import update_migration_version diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 649305e05..182467a49 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -44,12 +44,8 @@ from lnbits.decorators import ( require_admin_key, require_invoice_key, ) -from lnbits.helpers import ( - Extension, - InstallableExtension, - get_valid_extensions, - url_for, -) +from lnbits.extensions import Extension, InstallableExtension, get_valid_extensions +from lnbits.helpers import url_for from lnbits.settings import get_wallet_class, settings from lnbits.utils.exchange_rates import ( currencies, diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 1510df371..2c81fc31e 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -16,7 +16,7 @@ from lnbits.decorators import check_admin, check_user_exists from lnbits.helpers import template_renderer, url_for from lnbits.settings import get_wallet_class, settings -from ...helpers import InstallableExtension, get_valid_extensions +from ...extensions import InstallableExtension, get_valid_extensions from ..crud import ( USER_ID_ALL, create_account, diff --git a/lnbits/extensions.py b/lnbits/extensions.py new file mode 100644 index 000000000..9f45b1967 --- /dev/null +++ b/lnbits/extensions.py @@ -0,0 +1,277 @@ +import hashlib +import json +import os +import shutil +import sys +import urllib.request +import zipfile +from http import HTTPStatus +from typing import List, NamedTuple, Optional + +import httpx +from fastapi.exceptions import HTTPException +from fastapi.responses import JSONResponse +from loguru import logger +from starlette.types import ASGIApp, Receive, Scope, Send + +from lnbits.settings import settings + + +class Extension(NamedTuple): + code: str + is_valid: bool + is_admin_only: bool + name: Optional[str] = None + short_description: Optional[str] = None + tile: Optional[str] = None + contributors: Optional[List[str]] = None + hidden: bool = False + migration_module: Optional[str] = None + db_name: Optional[str] = None + hash: Optional[str] = "" + + @property + def module_name(self): + return ( + f"lnbits.extensions.{self.code}" + if self.hash == "" + else f"lnbits.upgrades.{self.code}-{self.hash}.{self.code}" + ) + + +class ExtensionManager: + def __init__(self, include_disabled_exts=False): + self._disabled: List[str] = ( + [] if include_disabled_exts else settings.lnbits_disabled_extensions + ) + self._admin_only: List[str] = settings.lnbits_admin_extensions + self._extension_folders: List[str] = [ + x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions")) + ][0] + + @property + def extensions(self) -> List[Extension]: + output: List[Extension] = [] + + if "all" in self._disabled: + return output + + for extension in [ + ext for ext in self._extension_folders if ext not in self._disabled + ]: + try: + with open( + os.path.join( + settings.lnbits_path, "extensions", extension, "config.json" + ) + ) as json_file: + config = json.load(json_file) + is_valid = True + is_admin_only = True if extension in self._admin_only else False + except Exception: + config = {} + is_valid = False + is_admin_only = False + + output.append( + Extension( + extension, + is_valid, + is_admin_only, + config.get("name"), + config.get("short_description"), + config.get("tile"), + config.get("contributors"), + config.get("hidden") or False, + config.get("migration_module"), + config.get("db_name"), + ) + ) + + return output + + +class InstallableExtension(NamedTuple): + id: str + name: str + archive: str + hash: str + short_description: Optional[str] = None + details: Optional[str] = None + icon: Optional[str] = None + dependencies: List[str] = [] + is_admin_only: bool = False + version: Optional[int] = 0 + + @property + def zip_path(self) -> str: + extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") + os.makedirs(extensions_data_dir, exist_ok=True) + return os.path.join(extensions_data_dir, f"{self.id}.zip") + + @property + def ext_dir(self) -> str: + return os.path.join("lnbits", "extensions", self.id) + + @property + def module_name(self) -> str: + return f"lnbits.extensions.{self.id}" + + @property + def module_installed(self) -> bool: + return self.module_name in sys.modules + + def download_archive(self): + ext_zip_file = self.zip_path + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + try: + download_url(self.archive, ext_zip_file) + except Exception as ex: + logger.warning(ex) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch extension archive file", + ) + + archive_hash = file_hash(ext_zip_file) + if self.hash != archive_hash: + # remove downloaded archive + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="File hash missmatch. Will not install.", + ) + + def extract_archive(self): + shutil.rmtree(self.ext_dir, True) + with zipfile.ZipFile(self.zip_path, "r") as zip_ref: + zip_ref.extractall(os.path.join("lnbits", "extensions")) + + ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}") + os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True) + shutil.rmtree(ext_upgrade_dir, True) + with zipfile.ZipFile(self.zip_path, "r") as zip_ref: + zip_ref.extractall(ext_upgrade_dir) + + @classmethod + async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension": + installable_extensions: List[ + InstallableExtension + ] = await InstallableExtension.get_installable_extensions() + + valid_extensions = [ + e for e in installable_extensions if e.id == ext_id and e.hash == hash + ] + if len(valid_extensions) == 0: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Unknown extension id: {ext_id}", + ) + extension = valid_extensions[0] + + # check that all dependecies are installed + installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) + if not set(extension.dependencies).issubset(installed_extensions): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"Not all dependencies are installed: {extension.dependencies}", + ) + + return extension + + @classmethod + async def get_installable_extensions(cls) -> List["InstallableExtension"]: + extension_list: List[InstallableExtension] = [] + + async with httpx.AsyncClient() as client: + for url in settings.lnbits_extensions_manifests: + resp = await client.get(url) + if resp.status_code != 200: + raise HTTPException( + status_code=404, + detail=f"Unable to fetch extension list for repository: {url}", + ) + for e in resp.json()["extensions"]: + extension_list += [ + InstallableExtension( + id=e["id"], + name=e["name"], + archive=e["archive"], + hash=e["hash"], + short_description=e["shortDescription"], + details=e["details"] if "details" in e else "", + icon=e["icon"], + dependencies=e["dependencies"] + if "dependencies" in e + else [], + ) + ] + + return extension_list + + +class InstalledExtensionMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if not "path" in scope: + await self.app(scope, receive, send) + return + + path_elements = scope["path"].split("/") + if len(path_elements) > 2: + _, path_name, path_type, *rest = path_elements + else: + _, path_name = path_elements + path_type = None + + # block path for all users if the extension is disabled + if path_name in settings.lnbits_disabled_extensions: + response = JSONResponse( + status_code=HTTPStatus.NOT_FOUND, + content={"detail": f"Extension '{path_name}' disabled"}, + ) + await response(scope, receive, send) + return + + # re-route API trafic if the extension has been upgraded + if path_type == "api": + upgraded_extensions = list( + filter( + lambda ext: ext.endswith(f"/{path_name}"), + settings.lnbits_upgraded_extensions, + ) + ) + if len(upgraded_extensions) != 0: + upgrade_path = upgraded_extensions[0] + tail = "/".join(rest) + scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}" + + await self.app(scope, receive, send) + + +def get_valid_extensions(include_disabled_exts=False) -> List[Extension]: + return [ + extension + for extension in ExtensionManager(include_disabled_exts).extensions + if extension.is_valid + ] + + +def download_url(url, save_path): + with urllib.request.urlopen(url) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() diff --git a/lnbits/helpers.py b/lnbits/helpers.py index d8d4bd3d9..edbd908cd 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -1,270 +1,15 @@ import glob -import hashlib -import json import os -import shutil -import sys -import urllib.request -import zipfile -from http import HTTPStatus -from typing import Any, List, NamedTuple, Optional +from typing import Any, List, Optional -import httpx import jinja2 import shortuuid # type: ignore -from fastapi.exceptions import HTTPException -from fastapi.responses import JSONResponse -from loguru import logger -from starlette.types import ASGIApp, Receive, Scope, Send from lnbits.jinja2_templating import Jinja2Templates from lnbits.requestvars import g from lnbits.settings import settings - -class Extension(NamedTuple): - code: str - is_valid: bool - is_admin_only: bool - name: Optional[str] = None - short_description: Optional[str] = None - tile: Optional[str] = None - contributors: Optional[List[str]] = None - hidden: bool = False - migration_module: Optional[str] = None - db_name: Optional[str] = None - hash: Optional[str] = "" - - @property - def module_name(self): - return ( - f"lnbits.extensions.{self.code}" - if self.hash == "" - else f"lnbits.upgrades.{self.code}-{self.hash}.{self.code}" - ) - - -class InstallableExtension(NamedTuple): - id: str - name: str - archive: str - hash: str - short_description: Optional[str] = None - details: Optional[str] = None - icon: Optional[str] = None - dependencies: List[str] = [] - is_admin_only: bool = False - version: Optional[int] = 0 - - @property - def zip_path(self) -> str: - extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") - os.makedirs(extensions_data_dir, exist_ok=True) - return os.path.join(extensions_data_dir, f"{self.id}.zip") - - @property - def ext_dir(self) -> str: - return os.path.join("lnbits", "extensions", self.id) - - @property - def module_name(self) -> str: - return f"lnbits.extensions.{self.id}" - - @property - def module_installed(self) -> bool: - return self.module_name in sys.modules - - def download_archive(self): - ext_zip_file = self.zip_path - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - try: - download_url(self.archive, ext_zip_file) - except Exception as ex: - logger.warning(ex) - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Cannot fetch extension archive file", - ) - - archive_hash = file_hash(ext_zip_file) - if self.hash != archive_hash: - # remove downloaded archive - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="File hash missmatch. Will not install.", - ) - - def extract_archive(self): - shutil.rmtree(self.ext_dir, True) - with zipfile.ZipFile(self.zip_path, "r") as zip_ref: - zip_ref.extractall(os.path.join("lnbits", "extensions")) - - ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}") - os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True) - shutil.rmtree(ext_upgrade_dir, True) - with zipfile.ZipFile(self.zip_path, "r") as zip_ref: - zip_ref.extractall(ext_upgrade_dir) - - @classmethod - async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension": - installable_extensions: List[ - InstallableExtension - ] = await InstallableExtension.get_installable_extensions() - - valid_extensions = [ - e for e in installable_extensions if e.id == ext_id and e.hash == hash - ] - if len(valid_extensions) == 0: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Unknown extension id: {ext_id}", - ) - extension = valid_extensions[0] - - # check that all dependecies are installed - installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) - if not set(extension.dependencies).issubset(installed_extensions): - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail=f"Not all dependencies are installed: {extension.dependencies}", - ) - - return extension - - @classmethod - async def get_installable_extensions(cls) -> List["InstallableExtension"]: - extension_list: List[InstallableExtension] = [] - - async with httpx.AsyncClient() as client: - for url in settings.lnbits_extensions_manifests: - resp = await client.get(url) - if resp.status_code != 200: - raise HTTPException( - status_code=404, - detail=f"Unable to fetch extension list for repository: {url}", - ) - for e in resp.json()["extensions"]: - extension_list += [ - InstallableExtension( - id=e["id"], - name=e["name"], - archive=e["archive"], - hash=e["hash"], - short_description=e["shortDescription"], - details=e["details"] if "details" in e else "", - icon=e["icon"], - dependencies=e["dependencies"] - if "dependencies" in e - else [], - ) - ] - - return extension_list - - -class ExtensionManager: - def __init__(self, include_disabled_exts=False): - self._disabled: List[str] = ( - [] if include_disabled_exts else settings.lnbits_disabled_extensions - ) - self._admin_only: List[str] = settings.lnbits_admin_extensions - self._extension_folders: List[str] = [ - x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions")) - ][0] - - @property - def extensions(self) -> List[Extension]: - output: List[Extension] = [] - - if "all" in self._disabled: - return output - - for extension in [ - ext for ext in self._extension_folders if ext not in self._disabled - ]: - try: - with open( - os.path.join( - settings.lnbits_path, "extensions", extension, "config.json" - ) - ) as json_file: - config = json.load(json_file) - is_valid = True - is_admin_only = True if extension in self._admin_only else False - except Exception: - config = {} - is_valid = False - is_admin_only = False - - output.append( - Extension( - extension, - is_valid, - is_admin_only, - config.get("name"), - config.get("short_description"), - config.get("tile"), - config.get("contributors"), - config.get("hidden") or False, - config.get("migration_module"), - config.get("db_name"), - ) - ) - - return output - - -class InstalledExtensionMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if not "path" in scope: - await self.app(scope, receive, send) - return - - path_elements = scope["path"].split("/") - if len(path_elements) > 2: - _, path_name, path_type, *rest = path_elements - else: - _, path_name = path_elements - path_type = None - - # block path for all users if the extension is disabled - if path_name in settings.lnbits_disabled_extensions: - response = JSONResponse( - status_code=HTTPStatus.NOT_FOUND, - content={"detail": f"Extension '{path_name}' disabled"}, - ) - await response(scope, receive, send) - return - - # re-route API trafic if the extension has been upgraded - if path_type == "api": - upgraded_extensions = list( - filter( - lambda ext: ext.endswith(f"/{path_name}"), - settings.lnbits_upgraded_extensions, - ) - ) - if len(upgraded_extensions) != 0: - upgrade_path = upgraded_extensions[0] - tail = "/".join(rest) - scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}" - - await self.app(scope, receive, send) - - -def get_valid_extensions(include_disabled_exts=False) -> List[Extension]: - return [ - extension - for extension in ExtensionManager(include_disabled_exts).extensions - if extension.is_valid - ] - +from .extensions import get_valid_extensions def urlsafe_short_hash() -> str: return shortuuid.uuid() @@ -397,19 +142,3 @@ def get_current_extension_name() -> str: except: ext_name = extension_director_name return ext_name - - -def download_url(url, save_path): - with urllib.request.urlopen(url) as dl_file: - with open(save_path, "wb") as out_file: - out_file.write(dl_file.read()) - - -def file_hash(filename): - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(filename, "rb", buffering=0) as f: - while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest()