fix pyright lnbits

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
Pavol Rusnak
2023-02-02 12:58:23 +00:00
committed by dni ⚡
parent 3855cf47f3
commit 02306148df
8 changed files with 266 additions and 116 deletions

View File

@@ -9,6 +9,7 @@ import secp256k1
from bech32 import CHARSET, bech32_decode, bech32_encode from bech32 import CHARSET, bech32_decode, bech32_encode
from ecdsa import SECP256k1, VerifyingKey from ecdsa import SECP256k1, VerifyingKey
from ecdsa.util import sigdecode_string from ecdsa.util import sigdecode_string
from loguru import logger
class Route(NamedTuple): class Route(NamedTuple):
@@ -30,6 +31,7 @@ class Invoice:
secret: Optional[str] = None secret: Optional[str] = None
route_hints: List[Route] = [] route_hints: List[Route] = []
min_final_cltv_expiry: int = 18 min_final_cltv_expiry: int = 18
checking_id: Optional[str] = None
def decode(pr: str) -> Invoice: def decode(pr: str) -> Invoice:
@@ -66,11 +68,13 @@ def decode(pr: str) -> Invoice:
invoice.amount_msat = _unshorten_amount(amountstr) invoice.amount_msat = _unshorten_amount(amountstr)
# pull out date # pull out date
invoice.date = data.read(35).uint date_bin = data.read(35)
assert date_bin
invoice.date = date_bin.uint
while data.pos != data.len: while data.pos != data.len:
tag, tagdata, data = _pull_tagged(data) tag, tagdata, data = _pull_tagged(data)
data_length = len(tagdata) / 5 data_length = len(tagdata or []) / 5
if tag == "d": if tag == "d":
invoice.description = _trim_to_bytes(tagdata).decode() invoice.description = _trim_to_bytes(tagdata).decode()
@@ -89,12 +93,22 @@ def decode(pr: str) -> Invoice:
elif tag == "r": elif tag == "r":
s = bitstring.ConstBitStream(tagdata) s = bitstring.ConstBitStream(tagdata)
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
pubkey = s.read(264)
assert pubkey
short_channel_id = s.read(64)
assert short_channel_id
base_fee_msat = s.read(32)
assert base_fee_msat
ppm_fee = s.read(32)
assert ppm_fee
cltv = s.read(16)
assert cltv
route = Route( route = Route(
pubkey=s.read(264).tobytes().hex(), pubkey=pubkey.tobytes().hex(),
short_channel_id=_readable_scid(s.read(64).intbe), short_channel_id=_readable_scid(short_channel_id.intbe),
base_fee_msat=s.read(32).intbe, base_fee_msat=base_fee_msat.intbe,
ppm_fee=s.read(32).intbe, ppm_fee=ppm_fee.intbe,
cltv=s.read(16).intbe, cltv=cltv.intbe,
) )
invoice.route_hints.append(route) invoice.route_hints.append(route)
@@ -160,6 +174,10 @@ def encode(options):
return lnencode(addr, options["privkey"]) return lnencode(addr, options["privkey"])
def encode_fallback(v, currency):
logger.error(f"hit bolt11.py encode_fallback with v: {v} and currency: {currency}")
def lnencode(addr, privkey): def lnencode(addr, privkey):
if addr.amount: if addr.amount:
amount = Decimal(str(addr.amount)) amount = Decimal(str(addr.amount))
@@ -244,7 +262,13 @@ def lnencode(addr, privkey):
class LnAddr: class LnAddr:
def __init__( def __init__(
self, paymenthash=None, amount=None, currency="bc", tags=None, date=None self,
paymenthash=None,
amount=None,
currency="bc",
tags=None,
date=None,
fallback=None,
): ):
self.date = int(time.time()) if not date else int(date) self.date = int(time.time()) if not date else int(date)
self.tags = [] if not tags else tags self.tags = [] if not tags else tags
@@ -252,6 +276,7 @@ class LnAddr:
self.paymenthash = paymenthash self.paymenthash = paymenthash
self.signature = None self.signature = None
self.pubkey = None self.pubkey = None
self.fallback = fallback
self.currency = currency self.currency = currency
self.amount = amount self.amount = amount
@@ -266,6 +291,7 @@ def shorten_amount(amount):
# Convert to pico initially # Convert to pico initially
amount = int(amount * 10**12) amount = int(amount * 10**12)
units = ["p", "n", "u", "m", ""] units = ["p", "n", "u", "m", ""]
unit = ""
for unit in units: for unit in units:
if amount % 1000 == 0: if amount % 1000 == 0:
amount //= 1000 amount //= 1000
@@ -304,14 +330,6 @@ def _pull_tagged(stream):
return (CHARSET[tag], stream.read(length * 5), stream) return (CHARSET[tag], stream.read(length * 5), stream)
def is_p2pkh(currency, prefix):
return prefix == base58_prefix_map[currency][0]
def is_p2sh(currency, prefix):
return prefix == base58_prefix_map[currency][1]
# Tagged field containing BitArray # Tagged field containing BitArray
def tagged(char, l): def tagged(char, l):
# Tagged fields need to be zero-padded to 5 bits. # Tagged fields need to be zero-padded to 5 bits.
@@ -359,5 +377,5 @@ def bitarray_to_u5(barr):
ret = [] ret = []
s = bitstring.ConstBitStream(barr) s = bitstring.ConstBitStream(barr)
while s.pos != s.len: while s.pos != s.len:
ret.append(s.read(5).uint) ret.append(s.read(5).uint) # type: ignore
return ret return ret

View File

@@ -41,6 +41,7 @@ async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them.""" """Creates the necessary databases if they don't exist already; or migrates them."""
async with core_db.connect() as conn: async with core_db.connect() as conn:
exists = False
if conn.type == SQLITE: if conn.type == SQLITE:
exists = await conn.fetchone( exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"

View File

@@ -131,7 +131,7 @@ class Database(Compat):
else: else:
self.type = POSTGRES self.type = POSTGRES
import psycopg2 from psycopg2.extensions import DECIMAL, new_type, register_type
def _parse_timestamp(value, _): def _parse_timestamp(value, _):
if value is None: if value is None:
@@ -141,15 +141,15 @@ class Database(Compat):
f = "%Y-%m-%d %H:%M:%S" f = "%Y-%m-%d %H:%M:%S"
return time.mktime(datetime.datetime.strptime(value, f).timetuple()) return time.mktime(datetime.datetime.strptime(value, f).timetuple())
psycopg2.extensions.register_type( register_type(
psycopg2.extensions.new_type( new_type(
psycopg2.extensions.DECIMAL.values, DECIMAL.values,
"DEC2FLOAT", "DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None, lambda value, curs: float(value) if value is not None else None,
) )
) )
psycopg2.extensions.register_type( register_type(
psycopg2.extensions.new_type( new_type(
(1082, 1083, 1266), (1082, 1083, 1266),
"DATE2INT", "DATE2INT",
lambda value, curs: time.mktime(value.timetuple()) lambda value, curs: time.mktime(value.timetuple())
@@ -158,11 +158,7 @@ class Database(Compat):
) )
) )
psycopg2.extensions.register_type( register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
psycopg2.extensions.new_type(
(1184, 1114), "TIMESTAMP2INT", _parse_timestamp
)
)
else: else:
if os.path.isdir(settings.lnbits_data_folder): if os.path.isdir(settings.lnbits_data_folder):
self.path = os.path.join( self.path = os.path.join(

View File

@@ -1,14 +1,12 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Type from typing import Optional, Type
from fastapi import Security, status from fastapi import HTTPException, Request, Security, status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.types import UUID4 from pydantic.types import UUID4
from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet from lnbits.core.models import User, Wallet
@@ -17,9 +15,13 @@ from lnbits.requestvars import g
from lnbits.settings import settings from lnbits.settings import settings
# TODO: fix type ignores
class KeyChecker(SecurityBase): class KeyChecker(SecurityBase):
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
self.scheme_name = scheme_name or self.__class__.__name__ self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error self.auto_error = auto_error
@@ -27,13 +29,13 @@ class KeyChecker(SecurityBase):
self._api_key = api_key self._api_key = api_key
if api_key: if api_key:
key = APIKey( key = APIKey(
**{"in": APIKeyIn.query}, **{"in": APIKeyIn.query}, # type: ignore
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - QUERY", description="Wallet API Key - QUERY",
) )
else: else:
key = APIKey( key = APIKey(
**{"in": APIKeyIn.header}, **{"in": APIKeyIn.header}, # type: ignore
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
@@ -73,7 +75,10 @@ class WalletInvoiceKeyChecker(KeyChecker):
""" """
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "invoice" self._key_type = "invoice"
@@ -89,7 +94,10 @@ class WalletAdminKeyChecker(KeyChecker):
""" """
def __init__( def __init__(
self, scheme_name: str = None, auto_error: bool = True, api_key: str = None self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
): ):
super().__init__(scheme_name, auto_error, api_key) super().__init__(scheme_name, auto_error, api_key)
self._key_type = "admin" self._key_type = "admin"

View File

@@ -3,20 +3,146 @@ import json
import os import os
import shutil import shutil
import sys import sys
import urllib.request
import zipfile import zipfile
from http import HTTPStatus from http import HTTPStatus
from pathlib import Path from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple from typing import Any, List, NamedTuple, Optional, Tuple
from urllib import request
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi import HTTPException
from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from lnbits.settings import settings from lnbits.settings import settings
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
def download_url(url, save_path):
with request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:
out_file.write(dl_file.read())
def file_hash(filename):
h = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with open(filename, "rb", buffering=0) as f:
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
async def fetch_github_repo_info(
org: str, repository: str
) -> Tuple[GitHubRepo, GitHubRepoRelease, ExtensionConfig]:
repo_url = f"https://api.github.com/repos/{org}/{repository}"
error_msg = "Cannot fetch extension repo"
repo = await gihub_api_get(repo_url, error_msg)
github_repo = GitHubRepo.parse_obj(repo)
lates_release_url = (
f"https://api.github.com/repos/{org}/{repository}/releases/latest"
)
error_msg = "Cannot fetch extension releases"
latest_release: Any = await gihub_api_get(lates_release_url, error_msg)
config_url = f"https://raw.githubusercontent.com/{org}/{repository}/{github_repo.default_branch}/config.json"
error_msg = "Cannot fetch config for extension"
config = await gihub_api_get(config_url, error_msg)
return (
github_repo,
GitHubRepoRelease.parse_obj(latest_release),
ExtensionConfig.parse_obj(config),
)
async def fetch_manifest(url) -> Manifest:
error_msg = "Cannot fetch extensions manifest"
manifest = await gihub_api_get(url, error_msg)
return Manifest.parse_obj(manifest)
async def fetch_github_releases(org: str, repo: str) -> List[GitHubRepoRelease]:
releases_url = f"https://api.github.com/repos/{org}/{repo}/releases"
error_msg = "Cannot fetch extension releases"
releases = await gihub_api_get(releases_url, error_msg)
return [GitHubRepoRelease.parse_obj(r) for r in releases]
async def gihub_api_get(url: str, error_msg: Optional[str]) -> Any:
async with httpx.AsyncClient() as client:
headers = (
{"Authorization": "Bearer " + settings.lnbits_ext_github_token}
if settings.lnbits_ext_github_token
else None
)
resp = await client.get(
url,
headers=headers,
)
if resp.status_code != 200:
logger.warning(f"{error_msg} ({url}): {resp.text}")
resp.raise_for_status()
return resp.json()
def icon_to_github_url(source_repo: str, path: Optional[str]) -> str:
if not path:
return ""
_, _, *rest = path.split("/")
tail = "/".join(rest)
return f"https://github.com/{source_repo}/raw/main/{tail}"
class Extension(NamedTuple): class Extension(NamedTuple):
code: str code: str
is_valid: bool is_valid: bool
@@ -97,12 +223,12 @@ class ExtensionRelease(BaseModel):
version: str version: str
archive: str archive: str
source_repo: str source_repo: str
is_github_release = False is_github_release: bool = False
hash: Optional[str] hash: Optional[str] = None
html_url: Optional[str] html_url: Optional[str] = None
description: Optional[str] description: Optional[str] = None
details_html: Optional[str] = None details_html: Optional[str] = None
icon: Optional[str] icon: Optional[str] = None
@classmethod @classmethod
def from_github_release( def from_github_release(
@@ -132,52 +258,6 @@ class ExtensionRelease(BaseModel):
return [] return []
class ExplicitRelease(BaseModel):
id: str
name: str
version: str
archive: str
hash: str
dependencies: List[str] = []
icon: Optional[str]
short_description: Optional[str]
html_url: Optional[str]
details: Optional[str]
info_notification: Optional[str]
critical_notification: Optional[str]
class GitHubRelease(BaseModel):
id: str
organisation: str
repository: str
class Manifest(BaseModel):
featured: List[str] = []
extensions: List["ExplicitRelease"] = []
repos: List["GitHubRelease"] = []
class GitHubRepoRelease(BaseModel):
name: str
tag_name: str
zipball_url: str
html_url: str
class GitHubRepo(BaseModel):
stargazers_count: str
html_url: str
default_branch: str
class ExtensionConfig(BaseModel):
name: str
short_description: str
tile: str = ""
class InstallableExtension(BaseModel): class InstallableExtension(BaseModel):
id: str id: str
name: str name: str
@@ -187,8 +267,9 @@ class InstallableExtension(BaseModel):
is_admin_only: bool = False is_admin_only: bool = False
stars: int = 0 stars: int = 0
featured = False featured = False
latest_release: Optional[ExtensionRelease] latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] installed_release: Optional[ExtensionRelease] = None
archive: Optional[str] = None
@property @property
def hash(self) -> str: def hash(self) -> str:
@@ -234,6 +315,7 @@ class InstallableExtension(BaseModel):
if ext_zip_file.is_file(): if ext_zip_file.is_file():
os.remove(ext_zip_file) os.remove(ext_zip_file)
try: try:
assert self.installed_release
download_url(self.installed_release.archive, ext_zip_file) download_url(self.installed_release.archive, ext_zip_file)
except Exception as ex: except Exception as ex:
logger.warning(ex) logger.warning(ex)
@@ -334,8 +416,7 @@ class InstallableExtension(BaseModel):
id=github_release.id, id=github_release.id,
name=config.name, name=config.name,
short_description=config.short_description, short_description=config.short_description,
version="0", stars=int(repo.stargazers_count),
stars=repo.stargazers_count,
icon=icon_to_github_url( icon=icon_to_github_url(
f"{github_release.organisation}/{github_release.repository}", f"{github_release.organisation}/{github_release.repository}",
config.tile, config.tile,
@@ -354,7 +435,6 @@ class InstallableExtension(BaseModel):
id=e.id, id=e.id,
name=e.name, name=e.name,
archive=e.archive, archive=e.archive,
hash=e.hash,
short_description=e.short_description, short_description=e.short_description,
icon=e.icon, icon=e.icon,
dependencies=e.dependencies, dependencies=e.dependencies,
@@ -443,6 +523,52 @@ class InstallableExtension(BaseModel):
return selected_release[0] if len(selected_release) != 0 else None return selected_release[0] if len(selected_release) != 0 else None
class InstalledExtensionMiddleware:
# This middleware class intercepts calls made to the extensions API and:
# - it blocks the calls if the extension has been disabled or uninstalled.
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
# - otherwise it has no effect
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" not in scope:
await self.app(scope, receive, send)
return
path_elements = scope["path"].split("/")
if len(path_elements) > 2:
_, path_name, path_type, *rest = path_elements
tail = "/".join(rest)
else:
_, path_name = path_elements
path_type = None
tail = ""
# block path for all users if the extension is disabled
if path_name in settings.lnbits_deactivated_extensions:
response = JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"Extension '{path_name}' disabled"},
)
await response(scope, receive, send)
return
# re-route API trafic if the extension has been upgraded
if path_type == "api":
upgraded_extensions = list(
filter(
lambda ext: ext.endswith(f"/{path_name}"),
settings.lnbits_upgraded_extensions,
)
)
if len(upgraded_extensions) != 0:
upgrade_path = upgraded_extensions[0]
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
await self.app(scope, receive, send)
class CreateExtension(BaseModel): class CreateExtension(BaseModel):
ext_id: str ext_id: str
archive: str archive: str

View File

@@ -1,25 +1,18 @@
# Borrowed from the excellent accent-starlette
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
import typing import typing
from starlette import templating from jinja2 import BaseLoader, Environment, pass_context
from starlette.datastructures import QueryParams from starlette.datastructures import QueryParams
from starlette.requests import Request from starlette.requests import Request
from starlette.templating import Jinja2Templates as SuperJinja2Templates
try:
import jinja2
except ImportError: # pragma: nocover
jinja2 = None # type: ignore
class Jinja2Templates(templating.Jinja2Templates): class Jinja2Templates(SuperJinja2Templates):
def __init__(self, loader: jinja2.BaseLoader) -> None: # pylint: disable=W0231 def __init__(self, loader: BaseLoader) -> None:
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" super().__init__("")
self.env = self.get_environment(loader) self.env = self.get_environment(loader)
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": def get_environment(self, loader: BaseLoader) -> Environment:
@jinja2.pass_context @pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> str: def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request: Request = context["request"] request: Request = context["request"]
return request.app.url_path_for(name, **path_params) return request.app.url_path_for(name, **path_params)
@@ -29,7 +22,7 @@ class Jinja2Templates(templating.Jinja2Templates):
values.update(new) values.update(new)
return QueryParams(**values) return QueryParams(**values)
env = jinja2.Environment(loader=loader, autoescape=True) env = Environment(loader=loader, autoescape=True)
env.globals["url_for"] = url_for env.globals["url_for"] = url_for
env.globals["url_params_update"] = url_params_update env.globals["url_params_update"] = url_params_update
return env return env

View File

@@ -24,6 +24,7 @@ def list_parse_fallback(v):
class LNbitsSettings(BaseSettings): class LNbitsSettings(BaseSettings):
@classmethod
def validate(cls, val): def validate(cls, val):
if type(val) == str: if type(val) == str:
val = val.split(",") if val else [] val = val.split(",") if val else []
@@ -103,6 +104,8 @@ class FakeWalletFundingSource(LNbitsSettings):
class LNbitsFundingSource(LNbitsSettings): class LNbitsFundingSource(LNbitsSettings):
lnbits_endpoint: str = Field(default="https://legend.lnbits.com") lnbits_endpoint: str = Field(default="https://legend.lnbits.com")
lnbits_key: Optional[str] = Field(default=None) lnbits_key: Optional[str] = Field(default=None)
lnbits_admin_key: Optional[str] = Field(default=None)
lnbits_invoice_key: Optional[str] = Field(default=None)
class ClicheFundingSource(LNbitsSettings): class ClicheFundingSource(LNbitsSettings):
@@ -145,11 +148,14 @@ class LnPayFundingSource(LNbitsSettings):
lnpay_api_endpoint: Optional[str] = Field(default=None) lnpay_api_endpoint: Optional[str] = Field(default=None)
lnpay_api_key: Optional[str] = Field(default=None) lnpay_api_key: Optional[str] = Field(default=None)
lnpay_wallet_key: Optional[str] = Field(default=None) lnpay_wallet_key: Optional[str] = Field(default=None)
lnpay_admin_key: Optional[str] = Field(default=None)
class OpenNodeFundingSource(LNbitsSettings): class OpenNodeFundingSource(LNbitsSettings):
opennode_api_endpoint: Optional[str] = Field(default=None) opennode_api_endpoint: Optional[str] = Field(default=None)
opennode_key: Optional[str] = Field(default=None) opennode_key: Optional[str] = Field(default=None)
opennode_admin_key: Optional[str] = Field(default=None)
opennode_invoice_key: Optional[str] = Field(default=None)
class SparkFundingSource(LNbitsSettings): class SparkFundingSource(LNbitsSettings):
@@ -208,8 +214,9 @@ class EditableSettings(
"lnbits_admin_extensions", "lnbits_admin_extensions",
pre=True, pre=True,
) )
@classmethod
def validate_editable_settings(cls, val): def validate_editable_settings(cls, val):
return super().validate(cls, val) return super().validate(val)
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
@@ -281,8 +288,9 @@ class ReadOnlySettings(
"lnbits_allowed_funding_sources", "lnbits_allowed_funding_sources",
pre=True, pre=True,
) )
@classmethod
def validate_readonly_settings(cls, val): def validate_readonly_settings(cls, val):
return super().validate(cls, val) return super().validate(val)
@classmethod @classmethod
def readonly_fields(cls): def readonly_fields(cls):

View File

@@ -3,7 +3,7 @@ import time
import traceback import traceback
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from typing import Dict from typing import Dict, Optional
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from loguru import logger from loguru import logger
@@ -42,7 +42,7 @@ class SseListenersDict(dict):
A dict of sse listeners. A dict of sse listeners.
""" """
def __init__(self, name: str = None): def __init__(self, name: Optional[str] = None):
self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}" self.name = name or f"sse_listener_{str(uuid.uuid4())[:8]}"
def __setitem__(self, key, value): def __setitem__(self, key, value):
@@ -65,7 +65,7 @@ class SseListenersDict(dict):
invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners") invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict("invoice_listeners")
def register_invoice_listener(send_chan: asyncio.Queue, name: str = None): def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = None):
""" """
A method intended for extensions (and core/tasks.py) to call when they want to be notified about A method intended for extensions (and core/tasks.py) to call when they want to be notified about
new invoice payments incoming. Will emit all incoming payments. new invoice payments incoming. Will emit all incoming payments.
@@ -164,7 +164,7 @@ async def check_pending_payments():
async def perform_balance_checks(): async def perform_balance_checks():
while True: while True:
for bc in await get_balance_checks(): for bc in await get_balance_checks():
redeem_lnurl_withdraw(bc.wallet, bc.url) await redeem_lnurl_withdraw(bc.wallet, bc.url)
await asyncio.sleep(60 * 60 * 6) # every 6 hours await asyncio.sleep(60 * 60 * 6) # every 6 hours