extension manager

This commit is contained in:
dni ⚡ 2024-09-27 09:38:45 +02:00 committed by Vlad Stan
parent 2ec7e3d23e
commit 64587b6222
9 changed files with 157 additions and 158 deletions

View File

@ -6,7 +6,7 @@ import shutil
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Callable, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -17,12 +17,10 @@ from slowapi.util import get_remote_address
from starlette.middleware.sessions import SessionMiddleware
from lnbits.core.crud import (
add_installed_extension,
get_dbversions,
get_installed_extensions,
update_installed_extension_state,
update_installed_extension,
)
from lnbits.core.extensions.extension_manager import deactivate_extension
from lnbits.core.extensions.helpers import version_parse
from lnbits.core.helpers import migrate_extension_database
from lnbits.core.tasks import ( # watchdog_task
@ -47,7 +45,7 @@ from lnbits.wallets import get_funding_source, set_funding_source
from .commands import migrate_databases
from .core import init_core_routers
from .core.db import core_app_extra
from .core.extensions.models import Extension, InstallableExtension
from .core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from .core.services import check_admin_settings, check_webpush_settings
from .middleware import (
CustomGZipMiddleware,
@ -240,7 +238,9 @@ async def check_installed_extensions(app: FastAPI):
)
except Exception as e:
logger.warning(e)
await deactivate_extension(ext.id)
settings.deactivate_extension_paths(ext.id)
ext.active = False
await update_installed_extension(ext)
logger.warning(
f"Failed to re-install extension: {ext.id} ({ext.installed_version})"
)
@ -252,7 +252,7 @@ async def check_installed_extensions(app: FastAPI):
async def build_all_installed_extensions_list(
include_deactivated: Optional[bool] = True,
) -> List[InstallableExtension]:
) -> list[InstallableExtension]:
"""
Returns a list of all the installed extensions plus the extensions that
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
@ -272,8 +272,13 @@ async def build_all_installed_extensions_list(
release = next((e for e in ext_releases if e.is_version_compatible), None)
if release:
ext_meta = ExtensionMeta(installed_release=release)
ext_info = InstallableExtension(
id=ext_id, name=ext_id, installed_release=release, icon=release.icon
id=ext_id,
name=ext_id,
version=release.version,
icon=release.icon,
meta=ext_meta,
)
installed_extensions.append(ext_info)
@ -304,8 +309,8 @@ async def check_installed_extension_files(ext: InstallableExtension) -> bool:
async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
await add_installed_extension(ext)
await update_installed_extension_state(ext_id=ext.id, active=True)
ext.active = True
await update_installed_extension(ext)
extension = Extension.from_installable_ext(ext)
register_ext_routes(app, extension)

View File

@ -3,7 +3,7 @@ import importlib
import time
from functools import wraps
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Optional
import click
import httpx
@ -231,7 +231,7 @@ async def check_invalid_payments(
click.echo("Funding source: " + str(funding_source))
# payments that are settled in the DB, but not at the Funding source level
invalid_payments: List[Payment] = []
invalid_payments: list[Payment] = []
invalid_wallets = {}
for db_payment in settled_db_payments:
if verbose:
@ -277,8 +277,10 @@ async def extensions_list():
from lnbits.app import build_all_installed_extensions_list
for ext in await build_all_installed_extensions_list():
assert ext.installed_release, f"Extension {ext.id} has no installed_release"
click.echo(f" - {ext.id} ({ext.installed_release.version})")
assert (
ext.meta and ext.meta.installed_release
), f"Extension {ext.id} has no installed_release"
click.echo(f" - {ext.id} ({ext.meta.installed_release.version})")
@extensions.command("update")
@ -461,7 +463,7 @@ async def install_extension(
source_repo: Optional[str] = None,
url: Optional[str] = None,
admin_user: Optional[str] = None,
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
try:
release = await _select_release(extension, repo_index, source_repo)
if not release:
@ -490,7 +492,7 @@ async def update_extension(
source_repo: Optional[str] = None,
url: Optional[str] = None,
admin_user: Optional[str] = None,
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
try:
click.echo(f"Updating '{extension}' extension.")
installed_ext = await get_installed_extension(extension)
@ -503,7 +505,7 @@ async def update_extension(
click.echo(f"Current '{extension}' version: {installed_ext.installed_version}.")
assert (
installed_ext.installed_release
installed_ext.meta and installed_ext.meta.installed_release
), "Cannot find previously installed release. Please uninstall first."
release = await _select_release(extension, repo_index, source_repo)
@ -511,7 +513,7 @@ async def update_extension(
return False, "No release selected."
if (
release.version == installed_ext.installed_version
and release.source_repo == installed_ext.installed_release.source_repo
and release.source_repo == installed_ext.meta.installed_release.source_repo
):
click.echo(f"Extension '{extension}' already up to date.")
return False, "Already up to date"

View File

@ -9,7 +9,6 @@ import shortuuid
from lnbits.core.db import db
from lnbits.core.extensions.models import (
InstallableExtension,
PayToEnableInfo,
UserExtension,
)
from lnbits.core.models import PaymentState
@ -258,42 +257,18 @@ async def get_user(account: Account, conn: Optional[Connection] = None) -> User:
# -------
async def add_installed_extension(
async def create_installed_extension(
ext: InstallableExtension,
conn: Optional[Connection] = None,
) -> None:
meta = {
"installed_release": (
dict(ext.installed_release) if ext.installed_release else None
),
"pay_to_enable": (dict(ext.pay_to_enable) if ext.pay_to_enable else None),
"dependencies": ext.dependencies,
"payments": [dict(p) for p in ext.payments] if ext.payments else None,
}
await (conn or db).insert("installed_extensions", ext)
version = ext.installed_release.version if ext.installed_release else ""
await (conn or db).execute(
"""
INSERT INTO installed_extensions
(id, version, name, active, short_description, icon, stars, meta)
VALUES
(:ext, :version, :name, :active, :short_description, :icon, :stars, :meta)
ON CONFLICT (id) DO UPDATE SET
(version, name, active, short_description, icon, stars, meta) =
(:version, :name, :active, :short_description, :icon, :stars, :meta)
""",
{
"ext": ext.id,
"version": version,
"name": ext.name,
"active": ext.active,
"short_description": ext.short_description,
"icon": ext.icon,
"stars": ext.stars,
"meta": json.dumps(meta),
},
)
async def update_installed_extension(
ext: InstallableExtension,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).update("installed_extensions", ext)
async def update_installed_extension_state(
@ -307,17 +282,6 @@ async def update_installed_extension_state(
)
async def update_extension_pay_to_enable(
ext_id: str, payment_info: PayToEnableInfo, conn: Optional[Connection] = None
) -> None:
ext = await get_installed_extension(ext_id, conn)
if not ext:
return
ext.pay_to_enable = payment_info
await add_installed_extension(ext, conn)
async def delete_installed_extension(
*, ext_id: str, conn: Optional[Connection] = None
) -> None:

View File

@ -3,14 +3,13 @@ import importlib
from loguru import logger
from lnbits.core import core_app_extra
from lnbits.core.crud import (
add_installed_extension,
delete_installed_extension,
get_dbversions,
get_installed_extension,
update_installed_extension_state,
)
from lnbits.core.db import core_app_extra
from lnbits.core.helpers import migrate_extension_database
from lnbits.settings import settings
@ -18,22 +17,24 @@ from .models import Extension, InstallableExtension
async def install_extension(ext_info: InstallableExtension) -> Extension:
ext_id = ext_info.id
extension = Extension.from_installable_ext(ext_info)
installed_ext = await get_installed_extension(ext_info.id)
installed_ext = await get_installed_extension(ext_id)
ext_info.payments = installed_ext.payments if installed_ext else []
await ext_info.download_archive()
ext_info.extract_archive()
db_version = (await get_dbversions()).get(ext_info.id, 0)
db_version = (await get_dbversions()).get(ext_id, 0)
await migrate_extension_database(extension, db_version)
await add_installed_extension(ext_info)
# TODO: think about add installed
# await add_installed_extension(ext_info)
if extension.is_upgrade_extension:
# call stop while the old routes are still active
await stop_extension_background_work(ext_info.id)
await stop_extension_background_work(ext_id)
return extension

View File

@ -373,29 +373,37 @@ class ExtensionRelease(BaseModel):
return None
class ExtensionMeta(BaseModel):
installed_release: Optional[ExtensionRelease] = None
pay_to_enable: Optional[PayToEnableInfo] = None
payments: list[ReleasePaymentInfo] = []
dependencies: list[str] = []
class InstallableExtension(BaseModel):
id: str
name: str
version: str
active: Optional[bool] = False
short_description: Optional[str] = None
icon: Optional[str] = None
dependencies: list[str] = []
is_admin_only: bool = False
stars: int = 0
featured = False
latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] = None
payments: list[ReleasePaymentInfo] = []
pay_to_enable: Optional[PayToEnableInfo] = None
archive: Optional[str] = None
latest_release: Optional[ExtensionRelease] = None
meta: Optional[ExtensionMeta] = None
@property
def is_admin_only(self) -> bool:
return self.id in settings.lnbits_admin_extensions
@property
def hash(self) -> str:
if self.installed_release:
if self.installed_release.hash:
return self.installed_release.hash
if self.meta and self.meta.installed_release:
if self.meta.installed_release.hash:
return self.meta.installed_release.hash
m = hashlib.sha256()
m.update(f"{self.installed_release.archive}".encode())
m.update(f"{self.meta.installed_release.archive}".encode())
return m.hexdigest()
return "not-installed"
@ -433,15 +441,15 @@ class InstallableExtension(BaseModel):
@property
def installed_version(self) -> str:
if self.installed_release:
return self.installed_release.version
if self.meta and self.meta.installed_release:
return self.meta.installed_release.version
return ""
@property
def requires_payment(self) -> bool:
if not self.pay_to_enable:
if not self.meta or not self.meta.pay_to_enable:
return False
return self.pay_to_enable.required is True
return self.meta.pay_to_enable.required is True
async def download_archive(self):
logger.info(f"Downloading extension {self.name} ({self.installed_version}).")
@ -449,12 +457,14 @@ class InstallableExtension(BaseModel):
if ext_zip_file.is_file():
os.remove(ext_zip_file)
try:
assert self.installed_release, "installed_release is none."
assert (
self.meta and self.meta.installed_release
), "installed_release is none."
self._restore_payment_info()
await asyncio.to_thread(
download_url, self.installed_release.archive_url, ext_zip_file
download_url, self.meta.installed_release.archive_url, ext_zip_file
)
self._remember_payment_info()
@ -464,7 +474,11 @@ class InstallableExtension(BaseModel):
raise AssertionError("Cannot fetch extension archive file") from exc
archive_hash = file_hash(ext_zip_file)
if self.installed_release.hash and self.installed_release.hash != archive_hash:
if (
self.meta
and self.meta.installed_release.hash
and self.meta.installed_release.hash != archive_hash
):
# remove downloaded archive
if ext_zip_file.is_file():
os.remove(ext_zip_file)
@ -498,12 +512,13 @@ class InstallableExtension(BaseModel):
self.short_description = config_json.get("short_description")
if (
self.installed_release
and self.installed_release.is_github_release
self.meta
and self.meta.installed_release
and self.meta.installed_release.is_github_release
and config_json.get("tile")
):
self.icon = icon_to_github_url(
self.installed_release.source_repo, config_json.get("tile")
self.meta.installed_release.source_repo, config_json.get("tile")
)
shutil.rmtree(self.ext_dir, True)
@ -540,48 +555,34 @@ class InstallableExtension(BaseModel):
)
def _restore_payment_info(self):
if not self.installed_release:
if (
not self.meta
or not self.meta.installed_release
or not self.meta.installed_release.pay_link
or not self.meta.installed_release.payment_hash
):
return
if not self.installed_release.pay_link:
return
if self.installed_release.payment_hash:
return
payment_info = self.find_existing_payment(self.installed_release.pay_link)
payment_info = self.find_existing_payment(self.meta.installed_release.pay_link)
if payment_info:
self.installed_release.payment_hash = payment_info.payment_hash
self.meta.installed_release.payment_hash = payment_info.payment_hash
def _remember_payment_info(self):
if not self.installed_release or not self.installed_release.pay_link:
if (
not self.meta
or not self.meta.installed_release
or not self.meta.installed_release.pay_link
):
return
payment_info = ReleasePaymentInfo(
amount=self.installed_release.cost_sats,
pay_link=self.installed_release.pay_link,
payment_hash=self.installed_release.payment_hash,
amount=self.meta.installed_release.cost_sats,
pay_link=self.meta.installed_release.pay_link,
payment_hash=self.meta.installed_release.payment_hash,
)
self.payments = [
p for p in self.payments if p.pay_link != payment_info.pay_link
]
self.payments.append(payment_info)
@classmethod
def from_row(cls, data: dict) -> InstallableExtension:
meta = json.loads(data["meta"])
ext = InstallableExtension(**data)
if "installed_release" in meta:
ext.installed_release = ExtensionRelease(**meta["installed_release"])
if meta.get("pay_to_enable"):
ext.pay_to_enable = PayToEnableInfo(**meta["pay_to_enable"])
if meta.get("payments"):
ext.payments = [ReleasePaymentInfo(**p) for p in meta["payments"]]
return ext
@classmethod
def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]:
if rows is None:
rows = []
return [InstallableExtension.from_row(row) for row in rows]
@classmethod
async def from_github_release(
cls, github_release: GitHubRelease
@ -594,6 +595,7 @@ class InstallableExtension(BaseModel):
return InstallableExtension(
id=github_release.id,
name=config.name,
version=latest_release.tag_name,
short_description=config.short_description,
stars=int(repo.stargazers_count),
icon=icon_to_github_url(
@ -610,13 +612,15 @@ class InstallableExtension(BaseModel):
@classmethod
def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension:
meta = ExtensionMeta(dependencies=e.dependencies)
return InstallableExtension(
id=e.id,
name=e.name,
version=e.version,
archive=e.archive,
short_description=e.short_description,
icon=e.icon,
dependencies=e.dependencies,
meta=meta,
)
@classmethod

View File

@ -541,8 +541,6 @@ async def m021_add_success_failed_to_apipayments(db):
GROUP BY apipayments.wallet
"""
)
# TODO: drop column in next release
# await db.execute("ALTER TABLE apipayments DROP COLUMN pending")
async def m022_add_pubkey_to_accounts(db):
@ -559,6 +557,7 @@ async def m023_add_column_column_to_apipayments(db):
"""
renames hash to payment_hash and drops unused index
"""
await db.execute("ALTER TABLE apipayments DROP COLUMN pending")
await db.execute("DROP INDEX by_hash")
await db.execute("ALTER TABLE apipayments RENAME COLUMN hash TO payment_hash")
await db.execute("ALTER TABLE apipayments RENAME COLUMN wallet TO wallet_id")

View File

@ -18,6 +18,7 @@ from lnbits.core.extensions.models import (
CreateExtension,
Extension,
ExtensionConfig,
ExtensionMeta,
ExtensionRelease,
InstallableExtension,
PayToEnableInfo,
@ -43,7 +44,7 @@ from ..crud import (
get_installed_extension,
get_installed_extensions,
get_user_extension,
update_extension_pay_to_enable,
update_installed_extension,
update_user_extension,
)
@ -69,8 +70,13 @@ async def api_install_extension(data: CreateExtension):
)
release.payment_hash = data.payment_hash
ext_meta = ExtensionMeta(installed_release=release)
ext_info = InstallableExtension(
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon
id=data.ext_id,
name=data.ext_id,
version=data.version,
meta=ext_meta,
icon=release.icon,
)
try:
@ -142,22 +148,21 @@ async def api_update_pay_to_enable(
data: PayToEnableInfo,
user: User = Depends(check_admin),
) -> SimpleStatus:
try:
assert (
data.wallet in user.wallet_ids
), "Wallet does not belong to this admin user."
await update_extension_pay_to_enable(ext_id, data)
return SimpleStatus(
success=True, message=f"Payment info updated for '{ext_id}' extension."
)
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
if data.wallet not in user.wallet_ids:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=(f"Failed to update pay to install data for extension '{ext_id}' "),
) from exc
HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
)
extension = await get_installed_extension(ext_id)
if not extension:
raise HTTPException(HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' not found.")
if extension.meta:
extension.meta.pay_to_enable = data
else:
extension.meta = ExtensionMeta(pay_to_enable=data)
await update_installed_extension(extension)
return SimpleStatus(
success=True, message=f"Payment info updated for '{ext_id}' extension."
)
@extension_router.put("/{ext_id}/enable")
@ -197,11 +202,11 @@ async def api_enable_extension(
)
assert (
ext.pay_to_enable and ext.pay_to_enable.wallet
ext.meta and ext.meta.pay_to_enable and ext.meta.pay_to_enable.wallet
), f"Extension '{ext_id}' is missing payment wallet."
payment_status = await check_transaction_status(
wallet_id=ext.pay_to_enable.wallet,
wallet_id=ext.meta.pay_to_enable.wallet,
payment_hash=user_ext.extra.payment_hash_to_enable,
)
@ -300,7 +305,11 @@ async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
installed_ext = next(
(ext for ext in installed_extensions if ext.id == valid_ext_id), None
)
if installed_ext and ext_id in installed_ext.dependencies:
if (
installed_ext
and installed_ext.meta
and ext_id in installed_ext.meta.dependencies
):
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=(
@ -399,35 +408,35 @@ async def get_pay_to_enable_invoice(
status_code=HTTPStatus.NOT_FOUND, detail=f"Extension '{ext_id}' not found."
)
if not ext.pay_to_enable:
if not ext.meta or not ext.meta.pay_to_enable:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment info not found for extension '{ext_id}'.",
)
if not ext.pay_to_enable.required:
if not ext.meta.pay_to_enable.required:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment not required for extension '{ext_id}'.",
)
if not ext.pay_to_enable.wallet or not ext.pay_to_enable.amount:
if not ext.meta.pay_to_enable.wallet or not ext.meta.pay_to_enable.amount:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment wallet or amount missing for extension '{ext_id}'.",
)
if data.amount < ext.pay_to_enable.amount:
if data.amount < ext.meta.pay_to_enable.amount:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=(
f"Amount {data.amount} sats is less than required "
f"{ext.pay_to_enable.amount} sats."
f"{ext.meta.pay_to_enable.amount} sats."
),
)
payment_hash, payment_request = await create_invoice(
wallet_id=ext.pay_to_enable.wallet,
wallet_id=ext.meta.pay_to_enable.wallet,
amount=data.amount,
memo=f"Enable '{ext.name}' extension.",
)

View File

@ -12,7 +12,7 @@ from lnurl import decode as lnurl_decode
from loguru import logger
from pydantic.types import UUID4
from lnbits.core.extensions.models import Extension, InstallableExtension
from lnbits.core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from lnbits.core.helpers import to_valid_user_id
from lnbits.core.models import User
from lnbits.core.services import create_invoice
@ -88,13 +88,21 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
for e in installable_exts:
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
if installed_ext:
e.installed_release = installed_ext.installed_release
if installed_ext.pay_to_enable and not user.admin:
if installed_ext and installed_ext.meta:
installed_release = installed_ext.meta.installed_release
if installed_ext.meta.pay_to_enable and not user.admin:
# not a security leak, but better not to share the wallet id
installed_ext.pay_to_enable.wallet = None
e.pay_to_enable = installed_ext.pay_to_enable
installed_ext.meta.pay_to_enable.wallet = None
pay_to_enable = installed_ext.meta.pay_to_enable
if e.meta:
e.meta.installed_release = installed_release
e.meta.pay_to_enable = pay_to_enable
else:
e.meta = ExtensionMeta(
installed_release=installed_release,
pay_to_enable=pay_to_enable,
)
# use the installed extension values
e.name = installed_ext.name
e.short_description = installed_ext.short_description
@ -119,7 +127,7 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
"shortDescription": ext.short_description,
"stars": ext.stars,
"isFeatured": ext.featured,
"dependencies": ext.dependencies,
"dependencies": ext.meta.dependencies if ext.meta else "",
"isInstalled": ext.id in installed_exts_ids,
"hasDatabaseTables": ext.id in db_version,
"isAvailable": ext.id in all_ext_ids,
@ -129,9 +137,15 @@ async def extensions(request: Request, user: User = Depends(check_user_exists)):
dict(ext.latest_release) if ext.latest_release else None
),
"installedRelease": (
dict(ext.installed_release) if ext.installed_release else None
dict(ext.meta.installed_release)
if ext.meta and ext.meta.installed_release
else None
),
"payToEnable": (
dict(ext.meta.pay_to_enable)
if ext.meta and ext.meta.pay_to_enable
else {}
),
"payToEnable": (dict(ext.pay_to_enable) if ext.pay_to_enable else {}),
"isPaymentRequired": ext.requires_payment,
}
for ext in installable_exts

View File

@ -611,7 +611,7 @@ def model_to_dict(model: BaseModel) -> dict:
return _dict
def dict_to_model(_dict: dict, model: type[TModel]) -> TModel:
def dict_to_model(_row: dict, model: type[TModel]) -> TModel:
"""
Convert a dictionary with JSON-encoded nested models to a Pydantic model
:param _dict: Dictionary from database
@ -619,9 +619,10 @@ def dict_to_model(_dict: dict, model: type[TModel]) -> TModel:
"""
# TODO: no recursion, maybe make them recursive?
# TODO: check why keys are sometimes not in the dict
for key, value in _dict.items():
_dict = {}
for key, value in _row.items():
if key not in model.__fields__:
# logger.warning(f"Converting {key} to model `{model}`.")
logger.warning(f"Converting {key} to model `{model}`.")
continue
type_ = model.__fields__[key].type_
if issubclass(type_, BaseModel):