mirror of
https://github.com/lnbits/lnbits.git
synced 2025-04-02 17:08:24 +02:00
escape user
This commit is contained in:
parent
90473b4723
commit
4cba21b962
lnbits
tests
@ -332,55 +332,23 @@ async def create_wallet(
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Wallet:
|
||||
wallet_id = uuid4().hex
|
||||
now = int(time())
|
||||
now_ph = db.timestamp_placeholder("now")
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
INSERT INTO wallets (id, name, "user", adminkey, inkey, created_at, updated_at)
|
||||
VALUES (:wallet, :name, :user, :adminkey, :inkey, {now_ph}, {now_ph})
|
||||
""",
|
||||
{
|
||||
"wallet": wallet_id,
|
||||
"name": wallet_name or settings.lnbits_default_wallet_name,
|
||||
"user": user_id,
|
||||
"adminkey": uuid4().hex,
|
||||
"inkey": uuid4().hex,
|
||||
"now": now,
|
||||
},
|
||||
wallet = Wallet(
|
||||
id=wallet_id,
|
||||
name=wallet_name or settings.lnbits_default_wallet_name,
|
||||
user=user_id,
|
||||
adminkey=uuid4().hex,
|
||||
inkey=uuid4().hex,
|
||||
)
|
||||
|
||||
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert new_wallet, "Newly created wallet couldn't be retrieved"
|
||||
|
||||
return new_wallet
|
||||
await (conn or db).update("wallets", wallet)
|
||||
return wallet
|
||||
|
||||
|
||||
async def update_wallet(
|
||||
wallet_id: str,
|
||||
name: Optional[str] = None,
|
||||
currency: Optional[str] = None,
|
||||
wallet: Wallet,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[Wallet]:
|
||||
set_clause = []
|
||||
set_clause.append(f"updated_at = {db.timestamp_placeholder('now')}")
|
||||
values: dict = {
|
||||
"wallet": wallet_id,
|
||||
"now": int(time()),
|
||||
}
|
||||
if name:
|
||||
set_clause.append("name = :name")
|
||||
values["name"] = name
|
||||
if currency is not None:
|
||||
set_clause.append("currency = :currency")
|
||||
values["currency"] = currency
|
||||
await (conn or db).execute(
|
||||
f"""
|
||||
UPDATE wallets SET {', '.join(set_clause)} WHERE id = :wallet
|
||||
""",
|
||||
values,
|
||||
)
|
||||
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert wallet, "updated created wallet couldn't be retrieved"
|
||||
wallet.updated_at = datetime.now(timezone.utc)
|
||||
await (conn or db).update("wallets", wallet)
|
||||
return wallet
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
|
||||
@ -41,11 +41,15 @@ class Wallet(BaseModel):
|
||||
name: str
|
||||
adminkey: str
|
||||
inkey: str
|
||||
currency: Optional[str]
|
||||
deleted: bool = False
|
||||
created_at: Optional[int] = None
|
||||
updated_at: Optional[int] = None
|
||||
balance_msat: int = 0
|
||||
created_at: datetime = datetime.now(timezone.utc)
|
||||
updated_at: datetime = datetime.now(timezone.utc)
|
||||
currency: Optional[str] = None
|
||||
|
||||
# @property
|
||||
# def balance_msat(self) -> int:
|
||||
# return self.balance_msat // 1000
|
||||
|
||||
@property
|
||||
def balance(self) -> int:
|
||||
@ -73,11 +77,6 @@ class Wallet(BaseModel):
|
||||
linking_key, curve=SECP256k1, hashfunc=hashlib.sha256
|
||||
)
|
||||
|
||||
async def get_payment(self, payment_hash: str) -> Optional[Payment]:
|
||||
from .crud import get_standalone_payment
|
||||
|
||||
return await get_standalone_payment(payment_hash)
|
||||
|
||||
|
||||
class KeyType(Enum):
|
||||
admin = 0
|
||||
@ -115,8 +114,8 @@ class Account(BaseModel):
|
||||
pubkey: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
extra: UserExtra = UserExtra()
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
created_at: datetime = datetime.now(timezone.utc)
|
||||
updated_at: datetime = datetime.now(timezone.utc)
|
||||
|
||||
@property
|
||||
def is_super_user(self) -> bool:
|
||||
|
@ -1,9 +1,11 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
)
|
||||
|
||||
from lnbits.core.models import (
|
||||
@ -20,6 +22,7 @@ from lnbits.decorators import (
|
||||
from ..crud import (
|
||||
create_wallet,
|
||||
delete_wallet,
|
||||
get_wallet,
|
||||
update_wallet,
|
||||
)
|
||||
|
||||
@ -27,35 +30,45 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"])
|
||||
|
||||
|
||||
@wallet_router.get("")
|
||||
async def api_wallet(wallet: WalletTypeInfo = Depends(require_invoice_key)):
|
||||
async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
|
||||
res = {
|
||||
"name": wallet.wallet.name,
|
||||
"balance": wallet.wallet.balance_msat,
|
||||
"name": key_info.wallet.name,
|
||||
"balance": key_info.wallet.balance_msat,
|
||||
}
|
||||
if wallet.key_type == KeyType.admin:
|
||||
res["id"] = wallet.wallet.id
|
||||
if key_info.key_type == KeyType.admin:
|
||||
res["id"] = key_info.wallet.id
|
||||
return res
|
||||
|
||||
|
||||
@wallet_router.put("/{new_name}")
|
||||
async def api_update_wallet_name(
|
||||
new_name: str, wallet: WalletTypeInfo = Depends(require_admin_key)
|
||||
new_name: str, key_info: WalletTypeInfo = Depends(require_admin_key)
|
||||
):
|
||||
await update_wallet(wallet.wallet.id, new_name)
|
||||
wallet = await get_wallet(key_info.wallet.id)
|
||||
if not wallet:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||
wallet.name = new_name
|
||||
await update_wallet(wallet)
|
||||
return {
|
||||
"id": wallet.wallet.id,
|
||||
"name": wallet.wallet.name,
|
||||
"balance": wallet.wallet.balance_msat,
|
||||
"id": wallet.id,
|
||||
"name": wallet.name,
|
||||
"balance": wallet.balance_msat,
|
||||
}
|
||||
|
||||
|
||||
@wallet_router.patch("", response_model=Wallet)
|
||||
@wallet_router.patch("")
|
||||
async def api_update_wallet(
|
||||
name: Optional[str] = Body(None),
|
||||
currency: Optional[str] = Body(None),
|
||||
wallet: WalletTypeInfo = Depends(require_admin_key),
|
||||
):
|
||||
return await update_wallet(wallet.wallet.id, name, currency)
|
||||
key_info: WalletTypeInfo = Depends(require_admin_key),
|
||||
) -> Wallet:
|
||||
wallet = await get_wallet(key_info.wallet.id)
|
||||
if not wallet:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||
wallet.name = name or wallet.name
|
||||
wallet.currency = currency or wallet.currency
|
||||
await update_wallet(wallet)
|
||||
return wallet
|
||||
|
||||
|
||||
@wallet_router.delete("")
|
||||
|
26
lnbits/db.py
26
lnbits/db.py
@ -283,21 +283,19 @@ class Database(Compat):
|
||||
|
||||
@event.listens_for(self.engine.sync_engine, "connect")
|
||||
def register_custom_types(dbapi_connection, *_):
|
||||
def _parse_timestamp(value):
|
||||
def _parse_date(value) -> datetime.datetime:
|
||||
if value is None:
|
||||
return None
|
||||
value = "1970-01-01 00:00:00"
|
||||
f = "%Y-%m-%d %H:%M:%S.%f"
|
||||
if "." not in value:
|
||||
f = "%Y-%m-%d %H:%M:%S"
|
||||
return int(
|
||||
time.mktime(datetime.datetime.strptime(value, f).timetuple())
|
||||
)
|
||||
return datetime.datetime.strptime(value, f)
|
||||
|
||||
dbapi_connection.run_async(
|
||||
lambda connection: connection.set_type_codec(
|
||||
"TIMESTAMP",
|
||||
encoder=datetime.datetime,
|
||||
decoder=_parse_timestamp,
|
||||
decoder=_parse_date,
|
||||
schema="pg_catalog",
|
||||
)
|
||||
)
|
||||
@ -574,7 +572,8 @@ def insert_query(table_name: str, model: BaseModel) -> str:
|
||||
placeholders = []
|
||||
for field in model.dict().keys():
|
||||
placeholders.append(get_placeholder(model, field))
|
||||
fields = ", ".join(model.dict().keys())
|
||||
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
|
||||
fields = ", ".join([f'"{key}"' for key in model.dict().keys()])
|
||||
values = ", ".join(placeholders)
|
||||
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"
|
||||
|
||||
@ -589,7 +588,8 @@ def update_query(table_name: str, model: BaseModel, where: str = "id = :id") ->
|
||||
fields = []
|
||||
for field in model.dict().keys():
|
||||
placeholder = get_placeholder(model, field)
|
||||
fields.append(f"{field} = {placeholder}")
|
||||
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
|
||||
fields.append(f'"{field}" = {placeholder}')
|
||||
query = ", ".join(fields)
|
||||
return f"UPDATE {table_name} SET {query} WHERE {where}"
|
||||
|
||||
@ -600,12 +600,12 @@ def model_to_dict(model: BaseModel) -> dict:
|
||||
private fields starting with _ are ignored
|
||||
:param model: Pydantic model
|
||||
"""
|
||||
_dict = {}
|
||||
_dict: dict = {}
|
||||
for key, value in model.dict().items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
type_ = model.__fields__[key].type_
|
||||
if type_ is datetime.datetime:
|
||||
if isinstance(value, datetime.datetime):
|
||||
_dict[key] = value.timestamp()
|
||||
continue
|
||||
if type(type_) is type(BaseModel):
|
||||
@ -643,9 +643,9 @@ def dict_to_model(_row: dict, model: type[TModel]) -> TModel:
|
||||
logger.warning(f"Converting {key} to model `{model}`.")
|
||||
continue
|
||||
type_ = model.__fields__[key].type_
|
||||
if issubclass(type_, datetime.datetime):
|
||||
_dict[key] = datetime.datetime.fromtimestamp(value)
|
||||
continue
|
||||
# if issubclass(type_, datetime.datetime):
|
||||
# _dict[key] = datetime.datetime.fromtimestamp(value)
|
||||
# continue
|
||||
if issubclass(type_, bool):
|
||||
_dict[key] = bool(value)
|
||||
continue
|
||||
|
@ -126,6 +126,7 @@ async def from_user():
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def from_wallet(from_user):
|
||||
user = from_user
|
||||
|
||||
wallet = await create_wallet(user_id=user.id, wallet_name="test_wallet_from")
|
||||
await update_wallet_balance(
|
||||
wallet_id=wallet.id,
|
||||
|
@ -26,6 +26,7 @@ class DbTestModel2(BaseModel):
|
||||
|
||||
class DbTestModel3(BaseModel):
|
||||
id: int
|
||||
user: str
|
||||
child: DbTestModel2
|
||||
active: bool = False
|
||||
|
||||
|
@ -12,6 +12,7 @@ from tests.helpers import DbTestModel, DbTestModel2, DbTestModel3
|
||||
|
||||
test_data = DbTestModel3(
|
||||
id=1,
|
||||
user="userid",
|
||||
child=DbTestModel2(
|
||||
id=2,
|
||||
label="test",
|
||||
@ -26,8 +27,8 @@ test_data = DbTestModel3(
|
||||
async def test_helpers_insert_query():
|
||||
q = insert_query("test_helpers_query", test_data)
|
||||
assert q == (
|
||||
"INSERT INTO test_helpers_query (id, child, active) "
|
||||
"VALUES (:id, :child, :active)"
|
||||
"""INSERT INTO test_helpers_query ("id", "user", "child", "active") """
|
||||
"VALUES (:id, :user, :child, :active)"
|
||||
)
|
||||
|
||||
|
||||
@ -35,8 +36,8 @@ async def test_helpers_insert_query():
|
||||
async def test_helpers_update_query():
|
||||
q = update_query("test_helpers_query", test_data)
|
||||
assert q == (
|
||||
"UPDATE test_helpers_query "
|
||||
"SET id = :id, child = :child, active = :active WHERE id = :id"
|
||||
"""UPDATE test_helpers_query SET "id" = :id, "user" = """
|
||||
""":user, "child" = :child, "active" = :active WHERE id = :id"""
|
||||
)
|
||||
|
||||
|
||||
@ -48,7 +49,7 @@ child_json = json.dumps(
|
||||
"child": {"id": 3, "name": "myname", "value": "myvalue"},
|
||||
}
|
||||
)
|
||||
test_dict = {"id": 1, "child": child_json, "active": True}
|
||||
test_dict = {"id": 1, "user": "userid", "child": child_json, "active": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -57,6 +58,7 @@ async def test_helpers_model_to_dict():
|
||||
assert d.get("id") == test_data.id
|
||||
assert d.get("active") == test_data.active
|
||||
assert d.get("child") == child_json
|
||||
assert d.get("user") == test_data.user
|
||||
assert d == test_dict
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user