escape user

This commit is contained in:
dni ⚡ 2024-10-03 07:37:38 +02:00
parent 90473b4723
commit 4cba21b962
No known key found for this signature in database
GPG Key ID: D1F416F29AD26E87
7 changed files with 70 additions and 86 deletions

@ -332,55 +332,23 @@ async def create_wallet(
conn: Optional[Connection] = None,
) -> Wallet:
wallet_id = uuid4().hex
now = int(time())
now_ph = db.timestamp_placeholder("now")
await (conn or db).execute(
f"""
INSERT INTO wallets (id, name, "user", adminkey, inkey, created_at, updated_at)
VALUES (:wallet, :name, :user, :adminkey, :inkey, {now_ph}, {now_ph})
""",
{
"wallet": wallet_id,
"name": wallet_name or settings.lnbits_default_wallet_name,
"user": user_id,
"adminkey": uuid4().hex,
"inkey": uuid4().hex,
"now": now,
},
wallet = Wallet(
id=wallet_id,
name=wallet_name or settings.lnbits_default_wallet_name,
user=user_id,
adminkey=uuid4().hex,
inkey=uuid4().hex,
)
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet
await (conn or db).update("wallets", wallet)
return wallet
async def update_wallet(
wallet_id: str,
name: Optional[str] = None,
currency: Optional[str] = None,
wallet: Wallet,
conn: Optional[Connection] = None,
) -> Optional[Wallet]:
set_clause = []
set_clause.append(f"updated_at = {db.timestamp_placeholder('now')}")
values: dict = {
"wallet": wallet_id,
"now": int(time()),
}
if name:
set_clause.append("name = :name")
values["name"] = name
if currency is not None:
set_clause.append("currency = :currency")
values["currency"] = currency
await (conn or db).execute(
f"""
UPDATE wallets SET {', '.join(set_clause)} WHERE id = :wallet
""",
values,
)
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert wallet, "updated created wallet couldn't be retrieved"
wallet.updated_at = datetime.now(timezone.utc)
await (conn or db).update("wallets", wallet)
return wallet

@ -4,7 +4,7 @@ import hashlib
import hmac
import time
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, Optional
@ -41,11 +41,15 @@ class Wallet(BaseModel):
name: str
adminkey: str
inkey: str
currency: Optional[str]
deleted: bool = False
created_at: Optional[int] = None
updated_at: Optional[int] = None
balance_msat: int = 0
created_at: datetime = datetime.now(timezone.utc)
updated_at: datetime = datetime.now(timezone.utc)
currency: Optional[str] = None
# @property
# def balance_msat(self) -> int:
# return self.balance_msat // 1000
@property
def balance(self) -> int:
@ -73,11 +77,6 @@ class Wallet(BaseModel):
linking_key, curve=SECP256k1, hashfunc=hashlib.sha256
)
async def get_payment(self, payment_hash: str) -> Optional[Payment]:
from .crud import get_standalone_payment
return await get_standalone_payment(payment_hash)
class KeyType(Enum):
admin = 0
@ -115,8 +114,8 @@ class Account(BaseModel):
pubkey: Optional[str] = None
email: Optional[str] = None
extra: UserExtra = UserExtra()
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
created_at: datetime = datetime.now(timezone.utc)
updated_at: datetime = datetime.now(timezone.utc)
@property
def is_super_user(self) -> bool:

@ -1,9 +1,11 @@
from http import HTTPStatus
from typing import Optional
from fastapi import (
APIRouter,
Body,
Depends,
HTTPException,
)
from lnbits.core.models import (
@ -20,6 +22,7 @@ from lnbits.decorators import (
from ..crud import (
create_wallet,
delete_wallet,
get_wallet,
update_wallet,
)
@ -27,35 +30,45 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"])
@wallet_router.get("")
async def api_wallet(wallet: WalletTypeInfo = Depends(require_invoice_key)):
async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
res = {
"name": wallet.wallet.name,
"balance": wallet.wallet.balance_msat,
"name": key_info.wallet.name,
"balance": key_info.wallet.balance_msat,
}
if wallet.key_type == KeyType.admin:
res["id"] = wallet.wallet.id
if key_info.key_type == KeyType.admin:
res["id"] = key_info.wallet.id
return res
@wallet_router.put("/{new_name}")
async def api_update_wallet_name(
new_name: str, wallet: WalletTypeInfo = Depends(require_admin_key)
new_name: str, key_info: WalletTypeInfo = Depends(require_admin_key)
):
await update_wallet(wallet.wallet.id, new_name)
wallet = await get_wallet(key_info.wallet.id)
if not wallet:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
wallet.name = new_name
await update_wallet(wallet)
return {
"id": wallet.wallet.id,
"name": wallet.wallet.name,
"balance": wallet.wallet.balance_msat,
"id": wallet.id,
"name": wallet.name,
"balance": wallet.balance_msat,
}
@wallet_router.patch("", response_model=Wallet)
@wallet_router.patch("")
async def api_update_wallet(
name: Optional[str] = Body(None),
currency: Optional[str] = Body(None),
wallet: WalletTypeInfo = Depends(require_admin_key),
):
return await update_wallet(wallet.wallet.id, name, currency)
key_info: WalletTypeInfo = Depends(require_admin_key),
) -> Wallet:
wallet = await get_wallet(key_info.wallet.id)
if not wallet:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
wallet.name = name or wallet.name
wallet.currency = currency or wallet.currency
await update_wallet(wallet)
return wallet
@wallet_router.delete("")

@ -283,21 +283,19 @@ class Database(Compat):
@event.listens_for(self.engine.sync_engine, "connect")
def register_custom_types(dbapi_connection, *_):
def _parse_timestamp(value):
def _parse_date(value) -> datetime.datetime:
if value is None:
return None
value = "1970-01-01 00:00:00"
f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value:
f = "%Y-%m-%d %H:%M:%S"
return int(
time.mktime(datetime.datetime.strptime(value, f).timetuple())
)
return datetime.datetime.strptime(value, f)
dbapi_connection.run_async(
lambda connection: connection.set_type_codec(
"TIMESTAMP",
encoder=datetime.datetime,
decoder=_parse_timestamp,
decoder=_parse_date,
schema="pg_catalog",
)
)
@ -574,7 +572,8 @@ def insert_query(table_name: str, model: BaseModel) -> str:
placeholders = []
for field in model.dict().keys():
placeholders.append(get_placeholder(model, field))
fields = ", ".join(model.dict().keys())
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields = ", ".join([f'"{key}"' for key in model.dict().keys()])
values = ", ".join(placeholders)
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"
@ -589,7 +588,8 @@ def update_query(table_name: str, model: BaseModel, where: str = "id = :id") ->
fields = []
for field in model.dict().keys():
placeholder = get_placeholder(model, field)
fields.append(f"{field} = {placeholder}")
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields.append(f'"{field}" = {placeholder}')
query = ", ".join(fields)
return f"UPDATE {table_name} SET {query} WHERE {where}"
@ -600,12 +600,12 @@ def model_to_dict(model: BaseModel) -> dict:
private fields starting with _ are ignored
:param model: Pydantic model
"""
_dict = {}
_dict: dict = {}
for key, value in model.dict().items():
if key.startswith("_"):
continue
type_ = model.__fields__[key].type_
if type_ is datetime.datetime:
if isinstance(value, datetime.datetime):
_dict[key] = value.timestamp()
continue
if type(type_) is type(BaseModel):
@ -643,9 +643,9 @@ def dict_to_model(_row: dict, model: type[TModel]) -> TModel:
logger.warning(f"Converting {key} to model `{model}`.")
continue
type_ = model.__fields__[key].type_
if issubclass(type_, datetime.datetime):
_dict[key] = datetime.datetime.fromtimestamp(value)
continue
# if issubclass(type_, datetime.datetime):
# _dict[key] = datetime.datetime.fromtimestamp(value)
# continue
if issubclass(type_, bool):
_dict[key] = bool(value)
continue

@ -126,6 +126,7 @@ async def from_user():
@pytest_asyncio.fixture(scope="session")
async def from_wallet(from_user):
user = from_user
wallet = await create_wallet(user_id=user.id, wallet_name="test_wallet_from")
await update_wallet_balance(
wallet_id=wallet.id,

@ -26,6 +26,7 @@ class DbTestModel2(BaseModel):
class DbTestModel3(BaseModel):
id: int
user: str
child: DbTestModel2
active: bool = False

@ -12,6 +12,7 @@ from tests.helpers import DbTestModel, DbTestModel2, DbTestModel3
test_data = DbTestModel3(
id=1,
user="userid",
child=DbTestModel2(
id=2,
label="test",
@ -26,8 +27,8 @@ test_data = DbTestModel3(
async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test_data)
assert q == (
"INSERT INTO test_helpers_query (id, child, active) "
"VALUES (:id, :child, :active)"
"""INSERT INTO test_helpers_query ("id", "user", "child", "active") """
"VALUES (:id, :user, :child, :active)"
)
@ -35,8 +36,8 @@ async def test_helpers_insert_query():
async def test_helpers_update_query():
q = update_query("test_helpers_query", test_data)
assert q == (
"UPDATE test_helpers_query "
"SET id = :id, child = :child, active = :active WHERE id = :id"
"""UPDATE test_helpers_query SET "id" = :id, "user" = """
""":user, "child" = :child, "active" = :active WHERE id = :id"""
)
@ -48,7 +49,7 @@ child_json = json.dumps(
"child": {"id": 3, "name": "myname", "value": "myvalue"},
}
)
test_dict = {"id": 1, "child": child_json, "active": True}
test_dict = {"id": 1, "user": "userid", "child": child_json, "active": True}
@pytest.mark.asyncio
@ -57,6 +58,7 @@ async def test_helpers_model_to_dict():
assert d.get("id") == test_data.id
assert d.get("active") == test_data.active
assert d.get("child") == child_json
assert d.get("user") == test_data.user
assert d == test_dict