mirror of
https://github.com/lnbits/lnbits.git
synced 2025-09-27 12:26:19 +02:00
refactor: extract extensions.py
This commit is contained in:
@@ -26,12 +26,10 @@ from .commands import migrate_databases
|
||||
from .core import core_app, core_app_extra
|
||||
from .core.services import check_admin_settings
|
||||
from .core.views.generic import core_html_routes
|
||||
from .extensions import Extension, InstalledExtensionMiddleware, get_valid_extensions
|
||||
from .helpers import (
|
||||
Extension,
|
||||
InstalledExtensionMiddleware,
|
||||
get_css_vendored,
|
||||
get_js_vendored,
|
||||
get_valid_extensions,
|
||||
template_renderer,
|
||||
url_for_vendored,
|
||||
)
|
||||
|
@@ -14,12 +14,8 @@ from .core import migrations as core_migrations
|
||||
from .core.crud import USER_ID_ALL, get_dbversions, get_inactive_extensions
|
||||
from .core.helpers import migrate_extension_database, run_migration
|
||||
from .db import COCKROACH, POSTGRES, SQLITE
|
||||
from .helpers import (
|
||||
get_css_vendored,
|
||||
get_js_vendored,
|
||||
get_valid_extensions,
|
||||
url_for_vendored,
|
||||
)
|
||||
from .extensions import get_valid_extensions
|
||||
from .helpers import get_css_vendored, get_js_vendored, url_for_vendored
|
||||
|
||||
|
||||
@click.command("migrate")
|
||||
|
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.helpers import Extension
|
||||
from lnbits.extensions import Extension
|
||||
|
||||
from . import db as core_db
|
||||
from .crud import update_migration_version
|
||||
|
@@ -44,12 +44,8 @@ from lnbits.decorators import (
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
)
|
||||
from lnbits.helpers import (
|
||||
Extension,
|
||||
InstallableExtension,
|
||||
get_valid_extensions,
|
||||
url_for,
|
||||
)
|
||||
from lnbits.extensions import Extension, InstallableExtension, get_valid_extensions
|
||||
from lnbits.helpers import url_for
|
||||
from lnbits.settings import get_wallet_class, settings
|
||||
from lnbits.utils.exchange_rates import (
|
||||
currencies,
|
||||
|
@@ -16,7 +16,7 @@ from lnbits.decorators import check_admin, check_user_exists
|
||||
from lnbits.helpers import template_renderer, url_for
|
||||
from lnbits.settings import get_wallet_class, settings
|
||||
|
||||
from ...helpers import InstallableExtension, get_valid_extensions
|
||||
from ...extensions import InstallableExtension, get_valid_extensions
|
||||
from ..crud import (
|
||||
USER_ID_ALL,
|
||||
create_account,
|
||||
|
277
lnbits/extensions.py
Normal file
277
lnbits/extensions.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from http import HTTPStatus
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
||||
class Extension(NamedTuple):
|
||||
code: str
|
||||
is_valid: bool
|
||||
is_admin_only: bool
|
||||
name: Optional[str] = None
|
||||
short_description: Optional[str] = None
|
||||
tile: Optional[str] = None
|
||||
contributors: Optional[List[str]] = None
|
||||
hidden: bool = False
|
||||
migration_module: Optional[str] = None
|
||||
db_name: Optional[str] = None
|
||||
hash: Optional[str] = ""
|
||||
|
||||
@property
|
||||
def module_name(self):
|
||||
return (
|
||||
f"lnbits.extensions.{self.code}"
|
||||
if self.hash == ""
|
||||
else f"lnbits.upgrades.{self.code}-{self.hash}.{self.code}"
|
||||
)
|
||||
|
||||
|
||||
class ExtensionManager:
|
||||
def __init__(self, include_disabled_exts=False):
|
||||
self._disabled: List[str] = (
|
||||
[] if include_disabled_exts else settings.lnbits_disabled_extensions
|
||||
)
|
||||
self._admin_only: List[str] = settings.lnbits_admin_extensions
|
||||
self._extension_folders: List[str] = [
|
||||
x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions"))
|
||||
][0]
|
||||
|
||||
@property
|
||||
def extensions(self) -> List[Extension]:
|
||||
output: List[Extension] = []
|
||||
|
||||
if "all" in self._disabled:
|
||||
return output
|
||||
|
||||
for extension in [
|
||||
ext for ext in self._extension_folders if ext not in self._disabled
|
||||
]:
|
||||
try:
|
||||
with open(
|
||||
os.path.join(
|
||||
settings.lnbits_path, "extensions", extension, "config.json"
|
||||
)
|
||||
) as json_file:
|
||||
config = json.load(json_file)
|
||||
is_valid = True
|
||||
is_admin_only = True if extension in self._admin_only else False
|
||||
except Exception:
|
||||
config = {}
|
||||
is_valid = False
|
||||
is_admin_only = False
|
||||
|
||||
output.append(
|
||||
Extension(
|
||||
extension,
|
||||
is_valid,
|
||||
is_admin_only,
|
||||
config.get("name"),
|
||||
config.get("short_description"),
|
||||
config.get("tile"),
|
||||
config.get("contributors"),
|
||||
config.get("hidden") or False,
|
||||
config.get("migration_module"),
|
||||
config.get("db_name"),
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class InstallableExtension(NamedTuple):
|
||||
id: str
|
||||
name: str
|
||||
archive: str
|
||||
hash: str
|
||||
short_description: Optional[str] = None
|
||||
details: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
dependencies: List[str] = []
|
||||
is_admin_only: bool = False
|
||||
version: Optional[int] = 0
|
||||
|
||||
@property
|
||||
def zip_path(self) -> str:
|
||||
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
|
||||
os.makedirs(extensions_data_dir, exist_ok=True)
|
||||
return os.path.join(extensions_data_dir, f"{self.id}.zip")
|
||||
|
||||
@property
|
||||
def ext_dir(self) -> str:
|
||||
return os.path.join("lnbits", "extensions", self.id)
|
||||
|
||||
@property
|
||||
def module_name(self) -> str:
|
||||
return f"lnbits.extensions.{self.id}"
|
||||
|
||||
@property
|
||||
def module_installed(self) -> bool:
|
||||
return self.module_name in sys.modules
|
||||
|
||||
def download_archive(self):
|
||||
ext_zip_file = self.zip_path
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
try:
|
||||
download_url(self.archive, ext_zip_file)
|
||||
except Exception as ex:
|
||||
logger.warning(ex)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Cannot fetch extension archive file",
|
||||
)
|
||||
|
||||
archive_hash = file_hash(ext_zip_file)
|
||||
if self.hash != archive_hash:
|
||||
# remove downloaded archive
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="File hash missmatch. Will not install.",
|
||||
)
|
||||
|
||||
def extract_archive(self):
|
||||
shutil.rmtree(self.ext_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(os.path.join("lnbits", "extensions"))
|
||||
|
||||
ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}")
|
||||
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True)
|
||||
shutil.rmtree(ext_upgrade_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(ext_upgrade_dir)
|
||||
|
||||
@classmethod
|
||||
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
|
||||
installable_extensions: List[
|
||||
InstallableExtension
|
||||
] = await InstallableExtension.get_installable_extensions()
|
||||
|
||||
valid_extensions = [
|
||||
e for e in installable_extensions if e.id == ext_id and e.hash == hash
|
||||
]
|
||||
if len(valid_extensions) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail=f"Unknown extension id: {ext_id}",
|
||||
)
|
||||
extension = valid_extensions[0]
|
||||
|
||||
# check that all dependecies are installed
|
||||
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
|
||||
if not set(extension.dependencies).issubset(installed_extensions):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail=f"Not all dependencies are installed: {extension.dependencies}",
|
||||
)
|
||||
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
async def get_installable_extensions(cls) -> List["InstallableExtension"]:
|
||||
extension_list: List[InstallableExtension] = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
for url in settings.lnbits_extensions_manifests:
|
||||
resp = await client.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unable to fetch extension list for repository: {url}",
|
||||
)
|
||||
for e in resp.json()["extensions"]:
|
||||
extension_list += [
|
||||
InstallableExtension(
|
||||
id=e["id"],
|
||||
name=e["name"],
|
||||
archive=e["archive"],
|
||||
hash=e["hash"],
|
||||
short_description=e["shortDescription"],
|
||||
details=e["details"] if "details" in e else "",
|
||||
icon=e["icon"],
|
||||
dependencies=e["dependencies"]
|
||||
if "dependencies" in e
|
||||
else [],
|
||||
)
|
||||
]
|
||||
|
||||
return extension_list
|
||||
|
||||
|
||||
class InstalledExtensionMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if not "path" in scope:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path_elements = scope["path"].split("/")
|
||||
if len(path_elements) > 2:
|
||||
_, path_name, path_type, *rest = path_elements
|
||||
else:
|
||||
_, path_name = path_elements
|
||||
path_type = None
|
||||
|
||||
# block path for all users if the extension is disabled
|
||||
if path_name in settings.lnbits_disabled_extensions:
|
||||
response = JSONResponse(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
content={"detail": f"Extension '{path_name}' disabled"},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# re-route API trafic if the extension has been upgraded
|
||||
if path_type == "api":
|
||||
upgraded_extensions = list(
|
||||
filter(
|
||||
lambda ext: ext.endswith(f"/{path_name}"),
|
||||
settings.lnbits_upgraded_extensions,
|
||||
)
|
||||
)
|
||||
if len(upgraded_extensions) != 0:
|
||||
upgrade_path = upgraded_extensions[0]
|
||||
tail = "/".join(rest)
|
||||
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def get_valid_extensions(include_disabled_exts=False) -> List[Extension]:
|
||||
return [
|
||||
extension
|
||||
for extension in ExtensionManager(include_disabled_exts).extensions
|
||||
if extension.is_valid
|
||||
]
|
||||
|
||||
|
||||
def download_url(url, save_path):
|
||||
with urllib.request.urlopen(url) as dl_file:
|
||||
with open(save_path, "wb") as out_file:
|
||||
out_file.write(dl_file.read())
|
||||
|
||||
|
||||
def file_hash(filename):
|
||||
h = hashlib.sha256()
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
with open(filename, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
return h.hexdigest()
|
@@ -1,270 +1,15 @@
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from http import HTTPStatus
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
import jinja2
|
||||
import shortuuid # type: ignore
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from lnbits.jinja2_templating import Jinja2Templates
|
||||
from lnbits.requestvars import g
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
||||
class Extension(NamedTuple):
|
||||
code: str
|
||||
is_valid: bool
|
||||
is_admin_only: bool
|
||||
name: Optional[str] = None
|
||||
short_description: Optional[str] = None
|
||||
tile: Optional[str] = None
|
||||
contributors: Optional[List[str]] = None
|
||||
hidden: bool = False
|
||||
migration_module: Optional[str] = None
|
||||
db_name: Optional[str] = None
|
||||
hash: Optional[str] = ""
|
||||
|
||||
@property
|
||||
def module_name(self):
|
||||
return (
|
||||
f"lnbits.extensions.{self.code}"
|
||||
if self.hash == ""
|
||||
else f"lnbits.upgrades.{self.code}-{self.hash}.{self.code}"
|
||||
)
|
||||
|
||||
|
||||
class InstallableExtension(NamedTuple):
|
||||
id: str
|
||||
name: str
|
||||
archive: str
|
||||
hash: str
|
||||
short_description: Optional[str] = None
|
||||
details: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
dependencies: List[str] = []
|
||||
is_admin_only: bool = False
|
||||
version: Optional[int] = 0
|
||||
|
||||
@property
|
||||
def zip_path(self) -> str:
|
||||
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
|
||||
os.makedirs(extensions_data_dir, exist_ok=True)
|
||||
return os.path.join(extensions_data_dir, f"{self.id}.zip")
|
||||
|
||||
@property
|
||||
def ext_dir(self) -> str:
|
||||
return os.path.join("lnbits", "extensions", self.id)
|
||||
|
||||
@property
|
||||
def module_name(self) -> str:
|
||||
return f"lnbits.extensions.{self.id}"
|
||||
|
||||
@property
|
||||
def module_installed(self) -> bool:
|
||||
return self.module_name in sys.modules
|
||||
|
||||
def download_archive(self):
|
||||
ext_zip_file = self.zip_path
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
try:
|
||||
download_url(self.archive, ext_zip_file)
|
||||
except Exception as ex:
|
||||
logger.warning(ex)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Cannot fetch extension archive file",
|
||||
)
|
||||
|
||||
archive_hash = file_hash(ext_zip_file)
|
||||
if self.hash != archive_hash:
|
||||
# remove downloaded archive
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="File hash missmatch. Will not install.",
|
||||
)
|
||||
|
||||
def extract_archive(self):
|
||||
shutil.rmtree(self.ext_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(os.path.join("lnbits", "extensions"))
|
||||
|
||||
ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}")
|
||||
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True)
|
||||
shutil.rmtree(ext_upgrade_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(ext_upgrade_dir)
|
||||
|
||||
@classmethod
|
||||
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
|
||||
installable_extensions: List[
|
||||
InstallableExtension
|
||||
] = await InstallableExtension.get_installable_extensions()
|
||||
|
||||
valid_extensions = [
|
||||
e for e in installable_extensions if e.id == ext_id and e.hash == hash
|
||||
]
|
||||
if len(valid_extensions) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail=f"Unknown extension id: {ext_id}",
|
||||
)
|
||||
extension = valid_extensions[0]
|
||||
|
||||
# check that all dependecies are installed
|
||||
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
|
||||
if not set(extension.dependencies).issubset(installed_extensions):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail=f"Not all dependencies are installed: {extension.dependencies}",
|
||||
)
|
||||
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
async def get_installable_extensions(cls) -> List["InstallableExtension"]:
|
||||
extension_list: List[InstallableExtension] = []
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
for url in settings.lnbits_extensions_manifests:
|
||||
resp = await client.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unable to fetch extension list for repository: {url}",
|
||||
)
|
||||
for e in resp.json()["extensions"]:
|
||||
extension_list += [
|
||||
InstallableExtension(
|
||||
id=e["id"],
|
||||
name=e["name"],
|
||||
archive=e["archive"],
|
||||
hash=e["hash"],
|
||||
short_description=e["shortDescription"],
|
||||
details=e["details"] if "details" in e else "",
|
||||
icon=e["icon"],
|
||||
dependencies=e["dependencies"]
|
||||
if "dependencies" in e
|
||||
else [],
|
||||
)
|
||||
]
|
||||
|
||||
return extension_list
|
||||
|
||||
|
||||
class ExtensionManager:
|
||||
def __init__(self, include_disabled_exts=False):
|
||||
self._disabled: List[str] = (
|
||||
[] if include_disabled_exts else settings.lnbits_disabled_extensions
|
||||
)
|
||||
self._admin_only: List[str] = settings.lnbits_admin_extensions
|
||||
self._extension_folders: List[str] = [
|
||||
x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions"))
|
||||
][0]
|
||||
|
||||
@property
|
||||
def extensions(self) -> List[Extension]:
|
||||
output: List[Extension] = []
|
||||
|
||||
if "all" in self._disabled:
|
||||
return output
|
||||
|
||||
for extension in [
|
||||
ext for ext in self._extension_folders if ext not in self._disabled
|
||||
]:
|
||||
try:
|
||||
with open(
|
||||
os.path.join(
|
||||
settings.lnbits_path, "extensions", extension, "config.json"
|
||||
)
|
||||
) as json_file:
|
||||
config = json.load(json_file)
|
||||
is_valid = True
|
||||
is_admin_only = True if extension in self._admin_only else False
|
||||
except Exception:
|
||||
config = {}
|
||||
is_valid = False
|
||||
is_admin_only = False
|
||||
|
||||
output.append(
|
||||
Extension(
|
||||
extension,
|
||||
is_valid,
|
||||
is_admin_only,
|
||||
config.get("name"),
|
||||
config.get("short_description"),
|
||||
config.get("tile"),
|
||||
config.get("contributors"),
|
||||
config.get("hidden") or False,
|
||||
config.get("migration_module"),
|
||||
config.get("db_name"),
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class InstalledExtensionMiddleware:
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if not "path" in scope:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path_elements = scope["path"].split("/")
|
||||
if len(path_elements) > 2:
|
||||
_, path_name, path_type, *rest = path_elements
|
||||
else:
|
||||
_, path_name = path_elements
|
||||
path_type = None
|
||||
|
||||
# block path for all users if the extension is disabled
|
||||
if path_name in settings.lnbits_disabled_extensions:
|
||||
response = JSONResponse(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
content={"detail": f"Extension '{path_name}' disabled"},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# re-route API trafic if the extension has been upgraded
|
||||
if path_type == "api":
|
||||
upgraded_extensions = list(
|
||||
filter(
|
||||
lambda ext: ext.endswith(f"/{path_name}"),
|
||||
settings.lnbits_upgraded_extensions,
|
||||
)
|
||||
)
|
||||
if len(upgraded_extensions) != 0:
|
||||
upgrade_path = upgraded_extensions[0]
|
||||
tail = "/".join(rest)
|
||||
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def get_valid_extensions(include_disabled_exts=False) -> List[Extension]:
|
||||
return [
|
||||
extension
|
||||
for extension in ExtensionManager(include_disabled_exts).extensions
|
||||
if extension.is_valid
|
||||
]
|
||||
|
||||
from .extensions import get_valid_extensions
|
||||
|
||||
def urlsafe_short_hash() -> str:
|
||||
return shortuuid.uuid()
|
||||
@@ -397,19 +142,3 @@ def get_current_extension_name() -> str:
|
||||
except:
|
||||
ext_name = extension_director_name
|
||||
return ext_name
|
||||
|
||||
|
||||
def download_url(url, save_path):
|
||||
with urllib.request.urlopen(url) as dl_file:
|
||||
with open(save_path, "wb") as out_file:
|
||||
out_file.write(dl_file.read())
|
||||
|
||||
|
||||
def file_hash(filename):
|
||||
h = hashlib.sha256()
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
with open(filename, "rb", buffering=0) as f:
|
||||
while n := f.readinto(mv):
|
||||
h.update(mv[:n])
|
||||
return h.hexdigest()
|
||||
|
Reference in New Issue
Block a user