refactor: more and more

This commit is contained in:
Vlad Stan
2023-01-17 16:28:24 +02:00
parent 77c17a2b63
commit 86c86958ae
5 changed files with 56 additions and 71 deletions

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
from lnbits import bolt11 from lnbits import bolt11
from lnbits.db import COCKROACH, POSTGRES, Connection from lnbits.db import COCKROACH, POSTGRES, Connection
from lnbits.extension_manger import InstallableExtension
from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, settings
from . import db from . import db
@@ -71,29 +72,39 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
async def add_installed_extension( async def add_installed_extension(
*, ext: InstallableExtension,
ext_id: str,
version: str,
name: str,
active: bool,
meta: dict,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ) -> None:
meta = {
"installed_release": dict(ext.installed_release)
if ext.installed_release
else None,
"dependencies": ext.dependencies,
}
version = ext.installed_release.version if ext.installed_release else ""
await (conn or db).execute( await (conn or db).execute(
""" """
INSERT INTO installed_extensions (id, version, name, active, meta) VALUES (?, ?, ?, ?, ?) INSERT INTO installed_extensions (id, version, name, short_description, icon, icon_url, stars, meta) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (id) DO ON CONFLICT (id) DO
UPDATE SET (version, name, active, meta) = (?, ?, ?, ?) UPDATE SET (version, name, short_description, icon, icon_url, stars, meta) = (?, ?, ?, ?, ?, ?, ?)
""", """,
( (
ext_id, ext.id,
version, version,
name, ext.name,
active, ext.short_description,
ext.icon,
ext.icon_url,
ext.stars,
json.dumps(meta), json.dumps(meta),
version, version,
name, ext.name,
active, ext.short_description,
ext.icon,
ext.icon_url,
ext.stars,
json.dumps(meta), json.dumps(meta),
), ),
) )
@@ -132,6 +143,16 @@ async def get_installed_extension(ext_id: str, conn: Optional[Connection] = None
return dict(row) return dict(row)
async def get_installed_extensions(
conn: Optional[Connection] = None,
) -> List["InstallableExtension"]:
rows = await (conn or db).fetchall(
"SELECT * FROM installed_extensions",
(),
)
return [InstallableExtension.from_row(row) for row in rows]
async def get_inactive_extensions(*, conn: Optional[Connection] = None) -> List[str]: async def get_inactive_extensions(*, conn: Optional[Connection] = None) -> List[str]:
inactive_extensions = await (conn or db).fetchall( inactive_extensions = await (conn or db).fetchall(
"""SELECT id FROM installed_extensions WHERE NOT active""", """SELECT id FROM installed_extensions WHERE NOT active""",

View File

@@ -278,6 +278,10 @@ async def m009_create_installed_extensions_table(db):
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
version TEXT NOT NULL, version TEXT NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
short_description TEXT,
icon TEXT,
icon_url TEXT,
stars INT NOT NULL DEFAULT 0,
active BOOLEAN DEFAULT false, active BOOLEAN DEFAULT false,
meta TEXT NOT NULL DEFAULT '{}' meta TEXT NOT NULL DEFAULT '{}'
); );

View File

@@ -747,13 +747,7 @@ async def api_install_extension(
db_version = (await get_dbversions()).get(data.ext_id, 0) db_version = (await get_dbversions()).get(data.ext_id, 0)
await migrate_extension_database(extension, db_version) await migrate_extension_database(extension, db_version)
await add_installed_extension( await add_installed_extension(ext_info)
ext_id=data.ext_id,
version=release.version,
name=ext_info.name,
active=False,
meta={"installed_release": dict(release)},
)
settings.lnbits_disabled_extensions += [data.ext_id] settings.lnbits_disabled_extensions += [data.ext_id]
# mount routes for the new version # mount routes for the new version

View File

@@ -23,6 +23,7 @@ from ..crud import (
delete_wallet, delete_wallet,
get_balance_check, get_balance_check,
get_inactive_extensions, get_inactive_extensions,
get_installed_extensions,
get_user, get_user,
save_balance_notify, save_balance_notify,
update_installed_extension_state, update_installed_extension_state,
@@ -75,9 +76,12 @@ async def extensions_install(
deactivate: str = Query(None), deactivate: str = Query(None),
): ):
try: try:
installed_extensions: List[
"InstallableExtension"
] = await get_installed_extensions()
extension_list: List[ extension_list: List[
InstallableExtension InstallableExtension
] = await InstallableExtension.get_installable_extensions() ] = await InstallableExtension.get_installable_extensions(installed_extensions)
except Exception as ex: except Exception as ex:
logger.warning(ex) logger.warning(ex)
extension_list = [] extension_list = []
@@ -96,7 +100,7 @@ async def extensions_install(
ext_id=ext_id, active=activate != None ext_id=ext_id, active=activate != None
) )
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) all_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
inactive_extensions = await get_inactive_extensions() inactive_extensions = await get_inactive_extensions()
extensions = list( extensions = list(
map( map(
@@ -108,7 +112,7 @@ async def extensions_install(
"shortDescription": ext.short_description, "shortDescription": ext.short_description,
"stars": ext.stars, "stars": ext.stars,
"dependencies": ext.dependencies, "dependencies": ext.dependencies,
"isInstalled": ext.id in installed_extensions, "isInstalled": ext.id in all_extensions,
"isActive": not ext.id in inactive_extensions, "isActive": not ext.id in inactive_extensions,
"latestRelease": dict(ext.latest_release) "latestRelease": dict(ext.latest_release)
if ext.latest_release if ext.latest_release

View File

@@ -15,7 +15,6 @@ from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.core.crud import get_installed_extension
from lnbits.settings import settings from lnbits.settings import settings
@@ -255,9 +254,7 @@ class InstallableExtension(BaseModel):
shutil.rmtree(self.ext_upgrade_dir, True) shutil.rmtree(self.ext_upgrade_dir, True)
@classmethod @classmethod
async def from_row(cls, data: dict) -> Optional["InstallableExtension"]: def from_row(cls, data: dict) -> "InstallableExtension":
if not data:
return None
meta = json.loads(data["meta"]) meta = json.loads(data["meta"])
ext = InstallableExtension(**data) ext = InstallableExtension(**data)
if "installed_release" in meta: if "installed_release" in meta:
@@ -269,9 +266,6 @@ class InstallableExtension(BaseModel):
cls, ext_id, org, repository cls, ext_id, org, repository
) -> Optional["InstallableExtension"]: ) -> Optional["InstallableExtension"]:
try: try:
installed_release = await InstallableExtension.from_row(
await get_installed_extension(ext_id)
)
repo, latest_release, config = await fetch_github_repo_info(org, repository) repo, latest_release, config = await fetch_github_repo_info(org, repository)
return InstallableExtension( return InstallableExtension(
@@ -281,7 +275,6 @@ class InstallableExtension(BaseModel):
version="0", version="0",
stars=repo["stargazers_count"], stars=repo["stargazers_count"],
icon_url=icon_to_github_url(org, config.get("tile")), icon_url=icon_to_github_url(org, config.get("tile")),
installed_release=installed_release,
latest_release=ExtensionRelease.from_github_release( latest_release=ExtensionRelease.from_github_release(
repo["html_url"], latest_release repo["html_url"], latest_release
), ),
@@ -291,10 +284,7 @@ class InstallableExtension(BaseModel):
return None return None
@classmethod @classmethod
async def from_manifest(cls, e: dict) -> "InstallableExtension": def from_manifest(cls, e: dict) -> "InstallableExtension":
installed_ext = await InstallableExtension.from_row(
await get_installed_extension(e["id"])
)
return InstallableExtension( return InstallableExtension(
id=e["id"], id=e["id"],
name=e["name"], name=e["name"],
@@ -302,42 +292,17 @@ class InstallableExtension(BaseModel):
hash=e["hash"], hash=e["hash"],
short_description=e["shortDescription"], short_description=e["shortDescription"],
icon=e["icon"], icon=e["icon"],
installed_release=installed_ext.installed_release
if installed_ext
else None,
dependencies=e["dependencies"] if "dependencies" in e else [], dependencies=e["dependencies"] if "dependencies" in e else [],
) )
@classmethod # todo: remove
async def get_extension_info(
cls, ext_id: str, archive: str
) -> "InstallableExtension":
installable_extensions: List[
InstallableExtension
] = await InstallableExtension.get_installable_extensions()
valid_extensions = [e for e in installable_extensions if e.id == ext_id]
if len(valid_extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = valid_extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
return extension
@classmethod @classmethod
async def get_installable_extensions(cls) -> List["InstallableExtension"]: async def get_installable_extensions(
extension_list: List[InstallableExtension] = [] cls, installed_extensions: List["InstallableExtension"] = []
extension_id_list: List[str] = [] ) -> List["InstallableExtension"]:
extension_list: List[InstallableExtension] = (
installed_extensions if installed_extensions else []
)
extension_id_list: List[str] = [e.id for e in extension_list]
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
for url in settings.lnbits_extensions_manifests: for url in settings.lnbits_extensions_manifests:
@@ -355,7 +320,6 @@ class InstallableExtension(BaseModel):
r["id"], r["organisation"], r["repository"] r["id"], r["organisation"], r["repository"]
) )
if ext: if ext:
extension_list += [ext] extension_list += [ext]
extension_id_list += [ext.id] extension_id_list += [ext.id]
@@ -363,9 +327,7 @@ class InstallableExtension(BaseModel):
for e in manifest["extensions"]: for e in manifest["extensions"]:
if e["id"] in extension_id_list: if e["id"] in extension_id_list:
continue continue
extension_list += [ extension_list += [InstallableExtension.from_manifest(e)]
await InstallableExtension.from_manifest(e)
]
extension_id_list += [e["id"]] extension_id_list += [e["id"]]
except Exception as e: except Exception as e:
logger.warning(f"Manifest {url} failed with '{str(e)}'") logger.warning(f"Manifest {url} failed with '{str(e)}'")