mirror of
https://github.com/lnbits/lnbits.git
synced 2025-09-22 15:57:28 +02:00
@@ -9,6 +9,7 @@ import secp256k1
|
|||||||
from bech32 import CHARSET, bech32_decode, bech32_encode
|
from bech32 import CHARSET, bech32_decode, bech32_encode
|
||||||
from ecdsa import SECP256k1, VerifyingKey
|
from ecdsa import SECP256k1, VerifyingKey
|
||||||
from ecdsa.util import sigdecode_string
|
from ecdsa.util import sigdecode_string
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class Route(NamedTuple):
|
class Route(NamedTuple):
|
||||||
@@ -30,6 +31,7 @@ class Invoice:
|
|||||||
secret: Optional[str] = None
|
secret: Optional[str] = None
|
||||||
route_hints: List[Route] = []
|
route_hints: List[Route] = []
|
||||||
min_final_cltv_expiry: int = 18
|
min_final_cltv_expiry: int = 18
|
||||||
|
checking_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def decode(pr: str) -> Invoice:
|
def decode(pr: str) -> Invoice:
|
||||||
@@ -66,11 +68,13 @@ def decode(pr: str) -> Invoice:
|
|||||||
invoice.amount_msat = _unshorten_amount(amountstr)
|
invoice.amount_msat = _unshorten_amount(amountstr)
|
||||||
|
|
||||||
# pull out date
|
# pull out date
|
||||||
invoice.date = data.read(35).uint
|
date_bin = data.read(35)
|
||||||
|
assert date_bin
|
||||||
|
invoice.date = date_bin.uint
|
||||||
|
|
||||||
while data.pos != data.len:
|
while data.pos != data.len:
|
||||||
tag, tagdata, data = _pull_tagged(data)
|
tag, tagdata, data = _pull_tagged(data)
|
||||||
data_length = len(tagdata) / 5
|
data_length = len(tagdata or []) / 5
|
||||||
|
|
||||||
if tag == "d":
|
if tag == "d":
|
||||||
invoice.description = _trim_to_bytes(tagdata).decode()
|
invoice.description = _trim_to_bytes(tagdata).decode()
|
||||||
@@ -89,12 +93,22 @@ def decode(pr: str) -> Invoice:
|
|||||||
elif tag == "r":
|
elif tag == "r":
|
||||||
s = bitstring.ConstBitStream(tagdata)
|
s = bitstring.ConstBitStream(tagdata)
|
||||||
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
||||||
|
pubkey = s.read(264)
|
||||||
|
assert pubkey
|
||||||
|
short_channel_id = s.read(64)
|
||||||
|
assert short_channel_id
|
||||||
|
base_fee_msat = s.read(32)
|
||||||
|
assert base_fee_msat
|
||||||
|
ppm_fee = s.read(32)
|
||||||
|
assert ppm_fee
|
||||||
|
cltv = s.read(16)
|
||||||
|
assert cltv
|
||||||
route = Route(
|
route = Route(
|
||||||
pubkey=s.read(264).tobytes().hex(),
|
pubkey=pubkey.tobytes().hex(),
|
||||||
short_channel_id=_readable_scid(s.read(64).intbe),
|
short_channel_id=_readable_scid(short_channel_id.intbe),
|
||||||
base_fee_msat=s.read(32).intbe,
|
base_fee_msat=base_fee_msat.intbe,
|
||||||
ppm_fee=s.read(32).intbe,
|
ppm_fee=ppm_fee.intbe,
|
||||||
cltv=s.read(16).intbe,
|
cltv=cltv.intbe,
|
||||||
)
|
)
|
||||||
invoice.route_hints.append(route)
|
invoice.route_hints.append(route)
|
||||||
|
|
||||||
@@ -160,6 +174,10 @@ def encode(options):
|
|||||||
return lnencode(addr, options["privkey"])
|
return lnencode(addr, options["privkey"])
|
||||||
|
|
||||||
|
|
||||||
|
def encode_fallback(v, currency):
|
||||||
|
logger.error(f"hit bolt11.py encode_fallback with v: {v} and currency: {currency}")
|
||||||
|
|
||||||
|
|
||||||
def lnencode(addr, privkey):
|
def lnencode(addr, privkey):
|
||||||
if addr.amount:
|
if addr.amount:
|
||||||
amount = Decimal(str(addr.amount))
|
amount = Decimal(str(addr.amount))
|
||||||
@@ -244,7 +262,13 @@ def lnencode(addr, privkey):
|
|||||||
|
|
||||||
class LnAddr:
|
class LnAddr:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None
|
self,
|
||||||
|
paymenthash=None,
|
||||||
|
amount=None,
|
||||||
|
currency="bc",
|
||||||
|
tags=None,
|
||||||
|
date=None,
|
||||||
|
fallback=None,
|
||||||
):
|
):
|
||||||
self.date = int(time.time()) if not date else int(date)
|
self.date = int(time.time()) if not date else int(date)
|
||||||
self.tags = [] if not tags else tags
|
self.tags = [] if not tags else tags
|
||||||
@@ -252,6 +276,7 @@ class LnAddr:
|
|||||||
self.paymenthash = paymenthash
|
self.paymenthash = paymenthash
|
||||||
self.signature = None
|
self.signature = None
|
||||||
self.pubkey = None
|
self.pubkey = None
|
||||||
|
self.fallback = fallback
|
||||||
self.currency = currency
|
self.currency = currency
|
||||||
self.amount = amount
|
self.amount = amount
|
||||||
|
|
||||||
@@ -266,6 +291,7 @@ def shorten_amount(amount):
|
|||||||
# Convert to pico initially
|
# Convert to pico initially
|
||||||
amount = int(amount * 10**12)
|
amount = int(amount * 10**12)
|
||||||
units = ["p", "n", "u", "m", ""]
|
units = ["p", "n", "u", "m", ""]
|
||||||
|
unit = ""
|
||||||
for unit in units:
|
for unit in units:
|
||||||
if amount % 1000 == 0:
|
if amount % 1000 == 0:
|
||||||
amount //= 1000
|
amount //= 1000
|
||||||
@@ -304,14 +330,6 @@ def _pull_tagged(stream):
|
|||||||
return (CHARSET[tag], stream.read(length * 5), stream)
|
return (CHARSET[tag], stream.read(length * 5), stream)
|
||||||
|
|
||||||
|
|
||||||
def is_p2pkh(currency, prefix):
|
|
||||||
return prefix == base58_prefix_map[currency][0]
|
|
||||||
|
|
||||||
|
|
||||||
def is_p2sh(currency, prefix):
|
|
||||||
return prefix == base58_prefix_map[currency][1]
|
|
||||||
|
|
||||||
|
|
||||||
# Tagged field containing BitArray
|
# Tagged field containing BitArray
|
||||||
def tagged(char, l):
|
def tagged(char, l):
|
||||||
# Tagged fields need to be zero-padded to 5 bits.
|
# Tagged fields need to be zero-padded to 5 bits.
|
||||||
@@ -359,5 +377,5 @@ def bitarray_to_u5(barr):
|
|||||||
ret = []
|
ret = []
|
||||||
s = bitstring.ConstBitStream(barr)
|
s = bitstring.ConstBitStream(barr)
|
||||||
while s.pos != s.len:
|
while s.pos != s.len:
|
||||||
ret.append(s.read(5).uint)
|
ret.append(s.read(5).uint) # type: ignore
|
||||||
return ret
|
return ret
|
||||||
|
@@ -41,6 +41,7 @@ async def migrate_databases():
|
|||||||
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
||||||
|
|
||||||
async with core_db.connect() as conn:
|
async with core_db.connect() as conn:
|
||||||
|
exists = False
|
||||||
if conn.type == SQLITE:
|
if conn.type == SQLITE:
|
||||||
exists = await conn.fetchone(
|
exists = await conn.fetchone(
|
||||||
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
||||||
|
18
lnbits/db.py
18
lnbits/db.py
@@ -131,7 +131,7 @@ class Database(Compat):
|
|||||||
else:
|
else:
|
||||||
self.type = POSTGRES
|
self.type = POSTGRES
|
||||||
|
|
||||||
import psycopg2
|
from psycopg2.extensions import DECIMAL, new_type, register_type
|
||||||
|
|
||||||
def _parse_timestamp(value, _):
|
def _parse_timestamp(value, _):
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -141,15 +141,15 @@ class Database(Compat):
|
|||||||
f = "%Y-%m-%d %H:%M:%S"
|
f = "%Y-%m-%d %H:%M:%S"
|
||||||
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
return time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||||
|
|
||||||
psycopg2.extensions.register_type(
|
register_type(
|
||||||
psycopg2.extensions.new_type(
|
new_type(
|
||||||
psycopg2.extensions.DECIMAL.values,
|
DECIMAL.values,
|
||||||
"DEC2FLOAT",
|
"DEC2FLOAT",
|
||||||
lambda value, curs: float(value) if value is not None else None,
|
lambda value, curs: float(value) if value is not None else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
psycopg2.extensions.register_type(
|
register_type(
|
||||||
psycopg2.extensions.new_type(
|
new_type(
|
||||||
(1082, 1083, 1266),
|
(1082, 1083, 1266),
|
||||||
"DATE2INT",
|
"DATE2INT",
|
||||||
lambda value, curs: time.mktime(value.timetuple())
|
lambda value, curs: time.mktime(value.timetuple())
|
||||||
@@ -158,11 +158,7 @@ class Database(Compat):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
psycopg2.extensions.register_type(
|
register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
|
||||||
psycopg2.extensions.new_type(
|
|
||||||
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if os.path.isdir(settings.lnbits_data_folder):
|
if os.path.isdir(settings.lnbits_data_folder):
|
||||||
self.path = os.path.join(
|
self.path = os.path.join(
|
||||||
|
@@ -1,14 +1,12 @@
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from fastapi import Security, status
|
from fastapi import HTTPException, Request, Security, status
|
||||||
from fastapi.exceptions import HTTPException
|
|
||||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.types import UUID4
|
from pydantic.types import UUID4
|
||||||
from starlette.requests import Request
|
|
||||||
|
|
||||||
from lnbits.core.crud import get_user, get_wallet_for_key
|
from lnbits.core.crud import get_user, get_wallet_for_key
|
||||||
from lnbits.core.models import User, Wallet
|
from lnbits.core.models import User, Wallet
|
||||||
@@ -17,9 +15,13 @@ from lnbits.requestvars import g
|
|||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: fix type ignores
|
||||||
class KeyChecker(SecurityBase):
|
class KeyChecker(SecurityBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.scheme_name = scheme_name or self.__class__.__name__
|
||||||
self.auto_error = auto_error
|
self.auto_error = auto_error
|
||||||
@@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
|
|||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
if api_key:
|
if api_key:
|
||||||
key = APIKey(
|
key = APIKey(
|
||||||
**{"in": APIKeyIn.query},
|
**{"in": APIKeyIn.query}, # type: ignore
|
||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - QUERY",
|
description="Wallet API Key - QUERY",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
key = APIKey(
|
key = APIKey(
|
||||||
**{"in": APIKeyIn.header},
|
**{"in": APIKeyIn.header}, # type: ignore
|
||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - HEADER",
|
description="Wallet API Key - HEADER",
|
||||||
)
|
)
|
||||||
@@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "invoice"
|
self._key_type = "invoice"
|
||||||
@@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None
|
self,
|
||||||
|
scheme_name: Optional[str] = None,
|
||||||
|
auto_error: bool = True,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
super().__init__(scheme_name, auto_error, api_key)
|
||||||
self._key_type = "admin"
|
self._key_type = "admin"
|
||||||
|
@@ -3,20 +3,146 @@ import json
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import urllib.request
|
|
||||||
import zipfile
|
import zipfile
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, NamedTuple, Optional, Tuple
|
from typing import Any, List, NamedTuple, Optional, Tuple
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class ExplicitRelease(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
archive: str
|
||||||
|
hash: str
|
||||||
|
dependencies: List[str] = []
|
||||||
|
icon: Optional[str]
|
||||||
|
short_description: Optional[str]
|
||||||
|
html_url: Optional[str]
|
||||||
|
details: Optional[str]
|
||||||
|
info_notification: Optional[str]
|
||||||
|
critical_notification: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRelease(BaseModel):
|
||||||
|
id: str
|
||||||
|
organisation: str
|
||||||
|
repository: str
|
||||||
|
|
||||||
|
|
||||||
|
class Manifest(BaseModel):
|
||||||
|
featured: List[str] = []
|
||||||
|
extensions: List["ExplicitRelease"] = []
|
||||||
|
repos: List["GitHubRelease"] = []
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRepoRelease(BaseModel):
|
||||||
|
name: str
|
||||||
|
tag_name: str
|
||||||
|
zipball_url: str
|
||||||
|
html_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubRepo(BaseModel):
|
||||||
|
stargazers_count: str
|
||||||
|
html_url: str
|
||||||
|
default_branch: str
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionConfig(BaseModel):
|
||||||
|
name: str
|
||||||
|
short_description: str
|
||||||
|
tile: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(url, save_path):
|
||||||
|
with request.urlopen(url) as dl_file:
|
||||||
|
with open(save_path, "wb") as out_file:
|
||||||
|
out_file.write(dl_file.read())
|
||||||
|
|
||||||
|
|
||||||
|
def file_hash(filename):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
b = bytearray(128 * 1024)
|
||||||
|
mv = memoryview(b)
|
||||||
|
with open(filename, "rb", buffering=0) as f:
|
||||||
|
while n := f.readinto(mv):
|
||||||
|
h.update(mv[:n])
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_github_repo_info(
|
||||||
|
org: str, repository: str
|
||||||
|
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
|
||||||
|
repo_url = f"https://api.github.com/repos/{org}/{repository}"
|
||||||
|
error_msg = "Cannot fetch extension repo"
|
||||||
|
repo = await gihub_api_get(repo_url, error_msg)
|
||||||
|
github_repo = GitHubRepo.parse_obj(repo)
|
||||||
|
|
||||||
|
lates_release_url = (
|
||||||
|
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
|
||||||
|
)
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
|
||||||
|
|
||||||
|
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
|
||||||
|
error_msg = "Cannot fetch config for extension"
|
||||||
|
config = await gihub_api_get(config_url, error_msg)
|
||||||
|
|
||||||
|
return (
|
||||||
|
github_repo,
|
||||||
|
GitHubRepoRelease.parse_obj(latest_release),
|
||||||
|
ExtensionConfig.parse_obj(config),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_manifest(url) -> Manifest:
|
||||||
|
error_msg = "Cannot fetch extensions manifest"
|
||||||
|
manifest = await gihub_api_get(url, error_msg)
|
||||||
|
return Manifest.parse_obj(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
|
||||||
|
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
|
||||||
|
error_msg = "Cannot fetch extension releases"
|
||||||
|
releases = await gihub_api_get(releases_url, error_msg)
|
||||||
|
return [GitHubRepoRelease.parse_obj(r) for r in releases]
|
||||||
|
|
||||||
|
|
||||||
|
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
headers = (
|
||||||
|
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
|
||||||
|
if settings.lnbits_ext_github_token
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
resp = await client.get(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if resp.status_code != 200:
|
||||||
|
logger.warning(f"{error_msg} ({url}): {resp.text}")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
|
||||||
|
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
|
||||||
|
if not path:
|
||||||
|
return ""
|
||||||
|
_, _, *rest = path.split("/")
|
||||||
|
tail = "/".join(rest)
|
||||||
|
return f"https://github.com/{source_repo}/raw/main/{tail}"
|
||||||
|
|
||||||
|
|
||||||
class Extension(NamedTuple):
|
class Extension(NamedTuple):
|
||||||
code: str
|
code: str
|
||||||
is_valid: bool
|
is_valid: bool
|
||||||
@@ -97,12 +223,12 @@ class ExtensionRelease(BaseModel):
|
|||||||
version: str
|
version: str
|
||||||
archive: str
|
archive: str
|
||||||
source_repo: str
|
source_repo: str
|
||||||
is_github_release = False
|
is_github_release: bool = False
|
||||||
hash: Optional[str]
|
hash: Optional[str] = None
|
||||||
html_url: Optional[str]
|
html_url: Optional[str] = None
|
||||||
description: Optional[str]
|
description: Optional[str] = None
|
||||||
details_html: Optional[str] = None
|
details_html: Optional[str] = None
|
||||||
icon: Optional[str]
|
icon: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_github_release(
|
def from_github_release(
|
||||||
@@ -132,52 +258,6 @@ class ExtensionRelease(BaseModel):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class ExplicitRelease(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
version: str
|
|
||||||
archive: str
|
|
||||||
hash: str
|
|
||||||
dependencies: List[str] = []
|
|
||||||
icon: Optional[str]
|
|
||||||
short_description: Optional[str]
|
|
||||||
html_url: Optional[str]
|
|
||||||
details: Optional[str]
|
|
||||||
info_notification: Optional[str]
|
|
||||||
critical_notification: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRelease(BaseModel):
|
|
||||||
id: str
|
|
||||||
organisation: str
|
|
||||||
repository: str
|
|
||||||
|
|
||||||
|
|
||||||
class Manifest(BaseModel):
|
|
||||||
featured: List[str] = []
|
|
||||||
extensions: List["ExplicitRelease"] = []
|
|
||||||
repos: List["GitHubRelease"] = []
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepoRelease(BaseModel):
|
|
||||||
name: str
|
|
||||||
tag_name: str
|
|
||||||
zipball_url: str
|
|
||||||
html_url: str
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubRepo(BaseModel):
|
|
||||||
stargazers_count: str
|
|
||||||
html_url: str
|
|
||||||
default_branch: str
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionConfig(BaseModel):
|
|
||||||
name: str
|
|
||||||
short_description: str
|
|
||||||
tile: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class InstallableExtension(BaseModel):
|
class InstallableExtension(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@@ -187,8 +267,9 @@ class InstallableExtension(BaseModel):
|
|||||||
is_admin_only: bool = False
|
is_admin_only: bool = False
|
||||||
stars: int = 0
|
stars: int = 0
|
||||||
featured = False
|
featured = False
|
||||||
latest_release: Optional[ExtensionRelease]
|
latest_release: Optional[ExtensionRelease] = None
|
||||||
installed_release: Optional[ExtensionRelease]
|
installed_release: Optional[ExtensionRelease] = None
|
||||||
|
archive: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self) -> str:
|
def hash(self) -> str:
|
||||||
@@ -234,6 +315,7 @@ class InstallableExtension(BaseModel):
|
|||||||
if ext_zip_file.is_file():
|
if ext_zip_file.is_file():
|
||||||
os.remove(ext_zip_file)
|
os.remove(ext_zip_file)
|
||||||
try:
|
try:
|
||||||
|
assert self.installed_release
|
||||||
download_url(self.installed_release.archive, ext_zip_file)
|
download_url(self.installed_release.archive, ext_zip_file)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning(ex)
|
logger.warning(ex)
|
||||||
@@ -334,8 +416,7 @@ class InstallableExtension(BaseModel):
|
|||||||
id=github_release.id,
|
id=github_release.id,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
short_description=config.short_description,
|
short_description=config.short_description,
|
||||||
version="0",
|
stars=int(repo.stargazers_count),
|
||||||
stars=repo.stargazers_count,
|
|
||||||
icon=icon_to_github_url(
|
icon=icon_to_github_url(
|
||||||
f"{github_release.organisation}/{github_release.repository}",
|
f"{github_release.organisation}/{github_release.repository}",
|
||||||
config.tile,
|
config.tile,
|
||||||
@@ -354,7 +435,6 @@ class InstallableExtension(BaseModel):
|
|||||||
id=e.id,
|
id=e.id,
|
||||||
name=e.name,
|
name=e.name,
|
||||||
archive=e.archive,
|
archive=e.archive,
|
||||||
hash=e.hash,
|
|
||||||
short_description=e.short_description,
|
short_description=e.short_description,
|
||||||
icon=e.icon,
|
icon=e.icon,
|
||||||
dependencies=e.dependencies,
|
dependencies=e.dependencies,
|
||||||
@@ -443,6 +523,52 @@ class InstallableExtension(BaseModel):
|
|||||||
return selected_release[0] if len(selected_release) != 0 else None
|
return selected_release[0] if len(selected_release) != 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledExtensionMiddleware:
|
||||||
|
# This middleware class intercepts calls made to the extensions API and:
|
||||||
|
# - it blocks the calls if the extension has been disabled or uninstalled.
|
||||||
|
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
|
||||||
|
# - otherwise it has no effect
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if "path" not in scope:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
path_elements = scope["path"].split("/")
|
||||||
|
if len(path_elements) > 2:
|
||||||
|
_, path_name, path_type, *rest = path_elements
|
||||||
|
tail = "/".join(rest)
|
||||||
|
else:
|
||||||
|
_, path_name = path_elements
|
||||||
|
path_type = None
|
||||||
|
tail = ""
|
||||||
|
|
||||||
|
# block path for all users if the extension is disabled
|
||||||
|
if path_name in settings.lnbits_deactivated_extensions:
|
||||||
|
response = JSONResponse(
|
||||||
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
content={"detail": f"Extension '{path_name}' disabled"},
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# re-route API trafic if the extension has been upgraded
|
||||||
|
if path_type == "api":
|
||||||
|
upgraded_extensions = list(
|
||||||
|
filter(
|
||||||
|
lambda ext: ext.endswith(f"/{path_name}"),
|
||||||
|
settings.lnbits_upgraded_extensions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(upgraded_extensions) != 0:
|
||||||
|
upgrade_path = upgraded_extensions[0]
|
||||||
|
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
class CreateExtension(BaseModel):
|
class CreateExtension(BaseModel):
|
||||||
ext_id: str
|
ext_id: str
|
||||||
archive: str
|
archive: str
|
||||||
|
@@ -1,25 +1,18 @@
|
|||||||
# Borrowed from the excellent accent-starlette
|
|
||||||
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
|
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from starlette import templating
|
from jinja2 import BaseLoader, Environment, pass_context
|
||||||
from starlette.datastructures import QueryParams
|
from starlette.datastructures import QueryParams
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from starlette.templating import Jinja2Templates as SuperJinja2Templates
|
||||||
try:
|
|
||||||
import jinja2
|
|
||||||
except ImportError: # pragma: nocover
|
|
||||||
jinja2 = None # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class Jinja2Templates(templating.Jinja2Templates):
|
class Jinja2Templates(SuperJinja2Templates):
|
||||||
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231
|
def __init__(self, loader: BaseLoader) -> None:
|
||||||
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
|
super().__init__("")
|
||||||
self.env = self.get_environment(loader)
|
self.env = self.get_environment(loader)
|
||||||
|
|
||||||
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
|
def get_environment(self, loader: BaseLoader) -> Environment:
|
||||||
@jinja2.pass_context
|
@pass_context
|
||||||
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
|
||||||
request: Request = context["request"]
|
request: Request = context["request"]
|
||||||
return request.app.url_path_for(name, **path_params)
|
return request.app.url_path_for(name, **path_params)
|
||||||
@@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
|
|||||||
values.update(new)
|
values.update(new)
|
||||||
return QueryParams(**values)
|
return QueryParams(**values)
|
||||||
|
|
||||||
env = jinja2.Environment(loader=loader, autoescape=True)
|
env = Environment(loader=loader, autoescape=True)
|
||||||
env.globals["url_for"] = url_for
|
env.globals["url_for"] = url_for
|
||||||
env.globals["url_params_update"] = url_params_update
|
env.globals["url_params_update"] = url_params_update
|
||||||
return env
|
return env
|
||||||
|
@@ -24,6 +24,7 @@ def list_parse_fallback(v):
|
|||||||
|
|
||||||
|
|
||||||
class LNbitsSettings(BaseSettings):
|
class LNbitsSettings(BaseSettings):
|
||||||
|
@classmethod
|
||||||
def validate(cls, val):
|
def validate(cls, val):
|
||||||
if type(val) == str:
|
if type(val) == str:
|
||||||
val = val.split(",") if val else []
|
val = val.split(",") if val else []
|
||||||
@@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
|
|||||||
class LNbitsFundingSource(LNbitsSettings):
|
class LNbitsFundingSource(LNbitsSettings):
|
||||||
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
|
lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
|
||||||
lnbits_key: Optional[str] = Field(default=None)
|
lnbits_key: Optional[str] = Field(default=None)
|
||||||
|
lnbits_admin_key: Optional[str] = Field(default=None)
|
||||||
|
lnbits_invoice_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class ClicheFundingSource(LNbitsSettings):
|
class ClicheFundingSource(LNbitsSettings):
|
||||||
@@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
|
|||||||
lnpay_api_endpoint: Optional[str] = Field(default=None)
|
lnpay_api_endpoint: Optional[str] = Field(default=None)
|
||||||
lnpay_api_key: Optional[str] = Field(default=None)
|
lnpay_api_key: Optional[str] = Field(default=None)
|
||||||
lnpay_wallet_key: Optional[str] = Field(default=None)
|
lnpay_wallet_key: Optional[str] = Field(default=None)
|
||||||
|
lnpay_admin_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class OpenNodeFundingSource(LNbitsSettings):
|
class OpenNodeFundingSource(LNbitsSettings):
|
||||||
opennode_api_endpoint: Optional[str] = Field(default=None)
|
opennode_api_endpoint: Optional[str] = Field(default=None)
|
||||||
opennode_key: Optional[str] = Field(default=None)
|
opennode_key: Optional[str] = Field(default=None)
|
||||||
|
opennode_admin_key: Optional[str] = Field(default=None)
|
||||||
|
opennode_invoice_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class SparkFundingSource(LNbitsSettings):
|
class SparkFundingSource(LNbitsSettings):
|
||||||
@@ -208,8 +214,9 @@ class EditableSettings(
|
|||||||
"lnbits_admin_extensions",
|
"lnbits_admin_extensions",
|
||||||
pre=True,
|
pre=True,
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
def validate_editable_settings(cls, val):
|
def validate_editable_settings(cls, val):
|
||||||
return super().validate(cls, val)
|
return super().validate(val)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, d: dict):
|
def from_dict(cls, d: dict):
|
||||||
@@ -281,8 +288,9 @@ class ReadOnlySettings(
|
|||||||
"lnbits_allowed_funding_sources",
|
"lnbits_allowed_funding_sources",
|
||||||
pre=True,
|
pre=True,
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
def validate_readonly_settings(cls, val):
|
def validate_readonly_settings(cls, val):
|
||||||
return super().validate(cls, val)
|
return super().validate(val)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def readonly_fields(cls):
|
def readonly_fields(cls):
|
||||||
|
@@ -3,7 +3,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -42,7 +42,7 @@ class SseListenersDict(dict):
|
|||||||
A dict of sse listeners.
|
A dict of sse listeners.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name: str = None):
|
def __init__(self, name: Optional[str] = None):
|
||||||
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
|
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
@@ -65,7 +65,7 @@ class SseListenersDict(dict):
|
|||||||
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
|
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
|
||||||
|
|
||||||
|
|
||||||
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None):
|
def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
|
A method intended for extensions (and core/tasks.py) to call when they want to be notified about
|
||||||
new invoice payments incoming. Will emit all incoming payments.
|
new invoice payments incoming. Will emit all incoming payments.
|
||||||
@@ -164,7 +164,7 @@ async def check_pending_payments():
|
|||||||
async def perform_balance_checks():
|
async def perform_balance_checks():
|
||||||
while True:
|
while True:
|
||||||
for bc in await get_balance_checks():
|
for bc in await get_balance_checks():
|
||||||
redeem_lnurl_withdraw(bc.wallet, bc.url)
|
await redeem_lnurl_withdraw(bc.wallet, bc.url)
|
||||||
|
|
||||||
await asyncio.sleep(60 * 60 * 6) # every 6 hours
|
await asyncio.sleep(60 * 60 * 6) # every 6 hours
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user