diff --git a/lnbits/bolt11.py b/lnbits/bolt11.py index a918cfbad..628c63c58 100644 --- a/lnbits/bolt11.py +++ b/lnbits/bolt11.py @@ -9,6 +9,7 @@ import secp256k1 from bech32 import CHARSET, bech32_decode, bech32_encode from ecdsa import SECP256k1, VerifyingKey from ecdsa.util import sigdecode_string +from loguru import logger class Route(NamedTuple): @@ -30,6 +31,7 @@ class Invoice: secret: Optional[str] = None route_hints: List[Route] = [] min_final_cltv_expiry: int = 18 + checking_id: Optional[str] = None def decode(pr: str) -> Invoice: @@ -66,11 +68,13 @@ def decode(pr: str) -> Invoice: invoice.amount_msat = _unshorten_amount(amountstr) # pull out date - invoice.date = data.read(35).uint + date_bin = data.read(35) + assert date_bin + invoice.date = date_bin.uint while data.pos != data.len: tag, tagdata, data = _pull_tagged(data) - data_length = len(tagdata) / 5 + data_length = len(tagdata or []) / 5 if tag == "d": invoice.description = _trim_to_bytes(tagdata).decode() @@ -89,12 +93,22 @@ def decode(pr: str) -> Invoice: elif tag == "r": s = bitstring.ConstBitStream(tagdata) while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: + pubkey = s.read(264) + assert pubkey + short_channel_id = s.read(64) + assert short_channel_id + base_fee_msat = s.read(32) + assert base_fee_msat + ppm_fee = s.read(32) + assert ppm_fee + cltv = s.read(16) + assert cltv route = Route( - pubkey=s.read(264).tobytes().hex(), - short_channel_id=_readable_scid(s.read(64).intbe), - base_fee_msat=s.read(32).intbe, - ppm_fee=s.read(32).intbe, - cltv=s.read(16).intbe, + pubkey=pubkey.tobytes().hex(), + short_channel_id=_readable_scid(short_channel_id.intbe), + base_fee_msat=base_fee_msat.intbe, + ppm_fee=ppm_fee.intbe, + cltv=cltv.intbe, ) invoice.route_hints.append(route) @@ -160,6 +174,10 @@ def encode(options): return lnencode(addr, options["privkey"]) +def encode_fallback(v, currency): + logger.error(f"hit bolt11.py encode_fallback with v: {v} and currency: {currency}") + + def lnencode(addr, privkey): if addr.amount: amount = Decimal(str(addr.amount)) @@ -244,7 +262,13 @@ def lnencode(addr, privkey): class LnAddr: def __init__( - self, paymenthash=None, amount=None, currency="bc", tags=None, date=None + self, + paymenthash=None, + amount=None, + currency="bc", + tags=None, + date=None, + fallback=None, ): self.date = int(time.time()) if not date else int(date) self.tags = [] if not tags else tags @@ -252,6 +276,7 @@ class LnAddr: self.paymenthash = paymenthash self.signature = None self.pubkey = None + self.fallback = fallback self.currency = currency self.amount = amount @@ -266,6 +291,7 @@ def shorten_amount(amount): # Convert to pico initially amount = int(amount * 10**12) units = ["p", "n", "u", "m", ""] + unit = "" for unit in units: if amount % 1000 == 0: amount //= 1000 @@ -304,14 +330,6 @@ def _pull_tagged(stream): return (CHARSET[tag], stream.read(length * 5), stream) -def is_p2pkh(currency, prefix): - return prefix == base58_prefix_map[currency][0] - - -def is_p2sh(currency, prefix): - return prefix == base58_prefix_map[currency][1] - - # Tagged field containing BitArray def tagged(char, l): # Tagged fields need to be zero-padded to 5 bits. @@ -359,5 +377,5 @@ def bitarray_to_u5(barr): ret = [] s = bitstring.ConstBitStream(barr) while s.pos != s.len: - ret.append(s.read(5).uint) + ret.append(s.read(5).uint) # type: ignore return ret diff --git a/lnbits/commands.py b/lnbits/commands.py index f2252fee1..e0f586051 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -41,6 +41,7 @@ async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" async with core_db.connect() as conn: + exists = False if conn.type == SQLITE: exists = await conn.fetchone( "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" diff --git a/lnbits/db.py b/lnbits/db.py index 3af11e36c..4b50b82c7 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -131,7 +131,7 @@ class Database(Compat): else: self.type = POSTGRES - import psycopg2 + from psycopg2.extensions import DECIMAL, new_type, register_type def _parse_timestamp(value, _): if value is None: @@ -141,15 +141,15 @@ class Database(Compat): f = "%Y-%m-%d %H:%M:%S" return time.mktime(datetime.datetime.strptime(value, f).timetuple()) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - psycopg2.extensions.DECIMAL.values, + register_type( + new_type( + DECIMAL.values, "DEC2FLOAT", lambda value, curs: float(value) if value is not None else None, ) ) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( + register_type( + new_type( (1082, 1083, 1266), "DATE2INT", lambda value, curs: time.mktime(value.timetuple()) @@ -158,11 +158,7 @@ class Database(Compat): ) ) - psycopg2.extensions.register_type( - psycopg2.extensions.new_type( - (1184, 1114), "TIMESTAMP2INT", _parse_timestamp - ) - ) + register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp)) else: if os.path.isdir(settings.lnbits_data_folder): self.path = os.path.join( diff --git a/lnbits/decorators.py b/lnbits/decorators.py index bd1c05207..3ced881a9 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,14 +1,12 @@ from http import HTTPStatus from typing import Optional, Type -from fastapi import Security, status -from fastapi.exceptions import HTTPException +from fastapi import HTTPException, Request, Security, status from fastapi.openapi.models import APIKey, APIKeyIn -from fastapi.security.api_key import APIKeyHeader, APIKeyQuery +from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security.base import SecurityBase from pydantic import BaseModel from pydantic.types import UUID4 -from starlette.requests import Request from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.models import User, Wallet @@ -17,9 +15,13 @@ from lnbits.requestvars import g from lnbits.settings import settings +# TODO: fix type ignores class KeyChecker(SecurityBase): def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error = auto_error @@ -27,13 +29,13 @@ class KeyChecker(SecurityBase): self._api_key = api_key if api_key: key = APIKey( - **{"in": APIKeyIn.query}, + **{"in": APIKeyIn.query}, # type: ignore name="X-API-KEY", description="Wallet API Key - QUERY", ) else: key = APIKey( - **{"in": APIKeyIn.header}, + **{"in": APIKeyIn.header}, # type: ignore name="X-API-KEY", description="Wallet API Key - HEADER", ) @@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker): """ def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) self._key_type = "invoice" @@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker): """ def __init__( - self, scheme_name: str = None, auto_error: bool = True, api_key: str = None + self, + scheme_name: Optional[str] = None, + auto_error: bool = True, + api_key: Optional[str] = None, ): super().__init__(scheme_name, auto_error, api_key) self._key_type = "admin" diff --git a/lnbits/extension_manager.py b/lnbits/extension_manager.py index 81faff106..27bcb3992 100644 --- a/lnbits/extension_manager.py +++ b/lnbits/extension_manager.py @@ -3,20 +3,146 @@ import json import os import shutil import sys -import urllib.request import zipfile from http import HTTPStatus from pathlib import Path from typing import Any, List, NamedTuple, Optional, Tuple +from urllib import request import httpx -from fastapi.exceptions import HTTPException +from fastapi import HTTPException +from fastapi.responses import JSONResponse from loguru import logger from pydantic import BaseModel from lnbits.settings import settings +class ExplicitRelease(BaseModel): + id: str + name: str + version: str + archive: str + hash: str + dependencies: List[str] = [] + icon: Optional[str] + short_description: Optional[str] + html_url: Optional[str] + details: Optional[str] + info_notification: Optional[str] + critical_notification: Optional[str] + + +class GitHubRelease(BaseModel): + id: str + organisation: str + repository: str + + +class Manifest(BaseModel): + featured: List[str] = [] + extensions: List["ExplicitRelease"] = [] + repos: List["GitHubRelease"] = [] + + +class GitHubRepoRelease(BaseModel): + name: str + tag_name: str + zipball_url: str + html_url: str + + +class GitHubRepo(BaseModel): + stargazers_count: str + html_url: str + default_branch: str + + +class ExtensionConfig(BaseModel): + name: str + short_description: str + tile: str = "" + + +def download_url(url, save_path): + with request.urlopen(url) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + + +async def fetch_github_repo_info( + org: str, repository: str +) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]: + repo_url = f"https://api.github.com/repos/{org}/{repository}" + error_msg = "Cannot fetch extension repo" + repo = await gihub_api_get(repo_url, error_msg) + github_repo = GitHubRepo.parse_obj(repo) + + lates_release_url = ( + f"https://api.github.com/repos/{org}/{repository}/releases/latest" + ) + error_msg = "Cannot fetch extension releases" + latest_release: Any = await gihub_api_get(lates_release_url, error_msg) + + config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json" + error_msg = "Cannot fetch config for extension" + config = await gihub_api_get(config_url, error_msg) + + return ( + github_repo, + GitHubRepoRelease.parse_obj(latest_release), + ExtensionConfig.parse_obj(config), + ) + + +async def fetch_manifest(url) -> Manifest: + error_msg = "Cannot fetch extensions manifest" + manifest = await gihub_api_get(url, error_msg) + return Manifest.parse_obj(manifest) + + +async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]: + releases_url = f"https://api.github.com/repos/{org}/{repo}/releases" + error_msg = "Cannot fetch extension releases" + releases = await gihub_api_get(releases_url, error_msg) + return [GitHubRepoRelease.parse_obj(r) for r in releases] + + +async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any: + async with httpx.AsyncClient() as client: + headers = ( + {"Authorization": "Bearer " + settings.lnbits_ext_github_token} + if settings.lnbits_ext_github_token + else None + ) + resp = await client.get( + url, + headers=headers, + ) + if resp.status_code != 200: + logger.warning(f"{error_msg} ({url}): {resp.text}") + resp.raise_for_status() + return resp.json() + + +def icon_to_github_url(source_repo: str, path: Optional[str]) -> str: + if not path: + return "" + _, _, *rest = path.split("/") + tail = "/".join(rest) + return f"https://github.com/{source_repo}/raw/main/{tail}" + + class Extension(NamedTuple): code: str is_valid: bool @@ -97,12 +223,12 @@ class ExtensionRelease(BaseModel): version: str archive: str source_repo: str - is_github_release = False - hash: Optional[str] - html_url: Optional[str] - description: Optional[str] + is_github_release: bool = False + hash: Optional[str] = None + html_url: Optional[str] = None + description: Optional[str] = None details_html: Optional[str] = None - icon: Optional[str] + icon: Optional[str] = None @classmethod def from_github_release( @@ -132,52 +258,6 @@ class ExtensionRelease(BaseModel): return [] -class ExplicitRelease(BaseModel): - id: str - name: str - version: str - archive: str - hash: str - dependencies: List[str] = [] - icon: Optional[str] - short_description: Optional[str] - html_url: Optional[str] - details: Optional[str] - info_notification: Optional[str] - critical_notification: Optional[str] - - -class GitHubRelease(BaseModel): - id: str - organisation: str - repository: str - - -class Manifest(BaseModel): - featured: List[str] = [] - extensions: List["ExplicitRelease"] = [] - repos: List["GitHubRelease"] = [] - - -class GitHubRepoRelease(BaseModel): - name: str - tag_name: str - zipball_url: str - html_url: str - - -class GitHubRepo(BaseModel): - stargazers_count: str - html_url: str - default_branch: str - - -class ExtensionConfig(BaseModel): - name: str - short_description: str - tile: str = "" - - class InstallableExtension(BaseModel): id: str name: str @@ -187,8 +267,9 @@ class InstallableExtension(BaseModel): is_admin_only: bool = False stars: int = 0 featured = False - latest_release: Optional[ExtensionRelease] - installed_release: Optional[ExtensionRelease] + latest_release: Optional[ExtensionRelease] = None + installed_release: Optional[ExtensionRelease] = None + archive: Optional[str] = None @property def hash(self) -> str: @@ -234,6 +315,7 @@ class InstallableExtension(BaseModel): if ext_zip_file.is_file(): os.remove(ext_zip_file) try: + assert self.installed_release download_url(self.installed_release.archive, ext_zip_file) except Exception as ex: logger.warning(ex) @@ -334,8 +416,7 @@ class InstallableExtension(BaseModel): id=github_release.id, name=config.name, short_description=config.short_description, - version="0", - stars=repo.stargazers_count, + stars=int(repo.stargazers_count), icon=icon_to_github_url( f"{github_release.organisation}/{github_release.repository}", config.tile, @@ -354,7 +435,6 @@ class InstallableExtension(BaseModel): id=e.id, name=e.name, archive=e.archive, - hash=e.hash, short_description=e.short_description, icon=e.icon, dependencies=e.dependencies, @@ -443,6 +523,52 @@ class InstallableExtension(BaseModel): return selected_release[0] if len(selected_release) != 0 else None +class InstalledExtensionMiddleware: + # This middleware class intercepts calls made to the extensions API and: + # - it blocks the calls if the extension has been disabled or uninstalled. + # - it redirects the calls to the latest version of the extension if the extension has been upgraded. + # - otherwise it has no effect + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if "path" not in scope: + await self.app(scope, receive, send) + return + + path_elements = scope["path"].split("/") + if len(path_elements) > 2: + _, path_name, path_type, *rest = path_elements + tail = "/".join(rest) + else: + _, path_name = path_elements + path_type = None + tail = "" + + # block path for all users if the extension is disabled + if path_name in settings.lnbits_deactivated_extensions: + response = JSONResponse( + status_code=HTTPStatus.NOT_FOUND, + content={"detail": f"Extension '{path_name}' disabled"}, + ) + await response(scope, receive, send) + return + + # re-route API trafic if the extension has been upgraded + if path_type == "api": + upgraded_extensions = list( + filter( + lambda ext: ext.endswith(f"/{path_name}"), + settings.lnbits_upgraded_extensions, + ) + ) + if len(upgraded_extensions) != 0: + upgrade_path = upgraded_extensions[0] + scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}" + + await self.app(scope, receive, send) + + class CreateExtension(BaseModel): ext_id: str archive: str diff --git a/lnbits/jinja2_templating.py b/lnbits/jinja2_templating.py index 5dfe36c36..f45394428 100644 --- a/lnbits/jinja2_templating.py +++ b/lnbits/jinja2_templating.py @@ -1,25 +1,18 @@ -# Borrowed from the excellent accent-starlette -# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py - import typing -from starlette import templating +from jinja2 import BaseLoader, Environment, pass_context from starlette.datastructures import QueryParams from starlette.requests import Request - -try: - import jinja2 -except ImportError: # pragma: nocover - jinja2 = None # type: ignore +from starlette.templating import Jinja2Templates as SuperJinja2Templates -class Jinja2Templates(templating.Jinja2Templates): - def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231 - assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" +class Jinja2Templates(SuperJinja2Templates): + def __init__(self, loader: BaseLoader) -> None: + super().__init__("") self.env = self.get_environment(loader) - def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": - @jinja2.pass_context + def get_environment(self, loader: BaseLoader) -> Environment: + @pass_context def url_for(context: dict, name: str, **path_params: typing.Any) -> str: request: Request = context["request"] return request.app.url_path_for(name, **path_params) @@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates): values.update(new) return QueryParams(**values) - env = jinja2.Environment(loader=loader, autoescape=True) + env = Environment(loader=loader, autoescape=True) env.globals["url_for"] = url_for env.globals["url_params_update"] = url_params_update return env diff --git a/lnbits/settings.py b/lnbits/settings.py index 62751b4f4..c5bdc43d4 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -24,6 +24,7 @@ def list_parse_fallback(v): class LNbitsSettings(BaseSettings): + @classmethod def validate(cls, val): if type(val) == str: val = val.split(",") if val else [] @@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings): class LNbitsFundingSource(LNbitsSettings): lnbits_endpoint: str = Field(default="https://legend.lnbits.com") lnbits_key: Optional[str] = Field(default=None) + lnbits_admin_key: Optional[str] = Field(default=None) + lnbits_invoice_key: Optional[str] = Field(default=None) class ClicheFundingSource(LNbitsSettings): @@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings): lnpay_api_endpoint: Optional[str] = Field(default=None) lnpay_api_key: Optional[str] = Field(default=None) lnpay_wallet_key: Optional[str] = Field(default=None) + lnpay_admin_key: Optional[str] = Field(default=None) class OpenNodeFundingSource(LNbitsSettings): opennode_api_endpoint: Optional[str] = Field(default=None) opennode_key: Optional[str] = Field(default=None) + opennode_admin_key: Optional[str] = Field(default=None) + opennode_invoice_key: Optional[str] = Field(default=None) class SparkFundingSource(LNbitsSettings): @@ -208,8 +214,9 @@ class EditableSettings( "lnbits_admin_extensions", pre=True, ) + @classmethod def validate_editable_settings(cls, val): - return super().validate(cls, val) + return super().validate(val) @classmethod def from_dict(cls, d: dict): @@ -281,8 +288,9 @@ class ReadOnlySettings( "lnbits_allowed_funding_sources", pre=True, ) + @classmethod def validate_readonly_settings(cls, val): - return super().validate(cls, val) + return super().validate(val) @classmethod def readonly_fields(cls): diff --git a/lnbits/tasks.py b/lnbits/tasks.py index f4d3bf7be..6c482256c 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -3,7 +3,7 @@ import time import traceback import uuid from http import HTTPStatus -from typing import Dict +from typing import Dict, Optional from fastapi.exceptions import HTTPException from loguru import logger @@ -42,7 +42,7 @@ class SseListenersDict(dict): A dict of sse listeners. """ - def __init__(self, name: str = None): + def __init__(self, name: Optional[str] = None): self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}" def __setitem__(self, key, value): @@ -65,7 +65,7 @@ class SseListenersDict(dict): invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners") -def register_invoice_listener(send_chan: asyncio.Queue, name: str = None): +def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None): """ A method intended for extensions (and core/tasks.py) to call when they want to be notified about new invoice payments incoming. Will emit all incoming payments. @@ -164,7 +164,7 @@ async def check_pending_payments(): async def perform_balance_checks(): while True: for bc in await get_balance_checks(): - redeem_lnurl_withdraw(bc.wallet, bc.url) + await redeem_lnurl_withdraw(bc.wallet, bc.url) await asyncio.sleep(60 * 60 * 6) # every 6 hours