mirror of
https://github.com/lnbits/lnbits.git
synced 2025-09-26 20:06:17 +02:00
support cockroachdb.
This commit is contained in:
@@ -5,7 +5,7 @@ import importlib
|
|||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .db import SQLITE, POSTGRES
|
from .db import SQLITE, POSTGRES, COCKROACH
|
||||||
from .core import db as core_db, migrations as core_migrations
|
from .core import db as core_db, migrations as core_migrations
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
get_valid_extensions,
|
get_valid_extensions,
|
||||||
@@ -83,7 +83,7 @@ async def migrate_databases():
|
|||||||
exists = await conn.fetchone(
|
exists = await conn.fetchone(
|
||||||
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
|
||||||
)
|
)
|
||||||
elif conn.type == POSTGRES:
|
elif conn.type in {POSTGRES, COCKROACH}:
|
||||||
exists = await conn.fetchone(
|
exists = await conn.fetchone(
|
||||||
"SELECT * FROM information_schema.tables WHERE table_name = 'dbversions'"
|
"SELECT * FROM information_schema.tables WHERE table_name = 'dbversions'"
|
||||||
)
|
)
|
||||||
|
@@ -5,7 +5,7 @@ from typing import List, Optional, Dict, Any
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from lnbits import bolt11
|
from lnbits import bolt11
|
||||||
from lnbits.db import Connection, POSTGRES
|
from lnbits.db import Connection, POSTGRES, COCKROACH
|
||||||
from lnbits.settings import DEFAULT_WALLET_NAME
|
from lnbits.settings import DEFAULT_WALLET_NAME
|
||||||
|
|
||||||
from . import db
|
from . import db
|
||||||
@@ -221,6 +221,8 @@ async def get_payments(
|
|||||||
if since != None:
|
if since != None:
|
||||||
if db.type == POSTGRES:
|
if db.type == POSTGRES:
|
||||||
clause.append("time > to_timestamp(?)")
|
clause.append("time > to_timestamp(?)")
|
||||||
|
elif db.type == COCKROACH:
|
||||||
|
clause.append("time > cast(? AS timestamp)")
|
||||||
else:
|
else:
|
||||||
clause.append("time > ?")
|
clause.append("time > ?")
|
||||||
args.append(since)
|
args.append(since)
|
||||||
|
40
lnbits/db.py
40
lnbits/db.py
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import trio
|
import trio
|
||||||
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy import create_engine # type: ignore
|
from sqlalchemy import create_engine # type: ignore
|
||||||
@@ -9,6 +10,7 @@ from sqlalchemy_aio.base import AsyncConnection # type: ignore
|
|||||||
from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL
|
from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL
|
||||||
|
|
||||||
POSTGRES = "POSTGRES"
|
POSTGRES = "POSTGRES"
|
||||||
|
COCKROACH = "COCKROACH"
|
||||||
SQLITE = "SQLITE"
|
SQLITE = "SQLITE"
|
||||||
|
|
||||||
|
|
||||||
@@ -17,7 +19,7 @@ class Compat:
|
|||||||
schema: Optional[str] = "<inherited>"
|
schema: Optional[str] = "<inherited>"
|
||||||
|
|
||||||
def interval_seconds(self, seconds: int) -> str:
|
def interval_seconds(self, seconds: int) -> str:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
return f"interval '{seconds} seconds'"
|
return f"interval '{seconds} seconds'"
|
||||||
elif self.type == SQLITE:
|
elif self.type == SQLITE:
|
||||||
return f"{seconds}"
|
return f"{seconds}"
|
||||||
@@ -25,7 +27,7 @@ class Compat:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def timestamp_now(self) -> str:
|
def timestamp_now(self) -> str:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
return "now()"
|
return "now()"
|
||||||
elif self.type == SQLITE:
|
elif self.type == SQLITE:
|
||||||
return "(strftime('%s', 'now'))"
|
return "(strftime('%s', 'now'))"
|
||||||
@@ -33,7 +35,7 @@ class Compat:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def serial_primary_key(self) -> str:
|
def serial_primary_key(self) -> str:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
return "SERIAL PRIMARY KEY"
|
return "SERIAL PRIMARY KEY"
|
||||||
elif self.type == SQLITE:
|
elif self.type == SQLITE:
|
||||||
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
return "INTEGER PRIMARY KEY AUTOINCREMENT"
|
||||||
@@ -41,7 +43,7 @@ class Compat:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def references_schema(self) -> str:
|
def references_schema(self) -> str:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
return f"{self.schema}."
|
return f"{self.schema}."
|
||||||
elif self.type == SQLITE:
|
elif self.type == SQLITE:
|
||||||
return ""
|
return ""
|
||||||
@@ -57,7 +59,7 @@ class Connection(Compat):
|
|||||||
self.schema = schema
|
self.schema = schema
|
||||||
|
|
||||||
def rewrite_query(self, query) -> str:
|
def rewrite_query(self, query) -> str:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
query = query.replace("%", "%%")
|
query = query.replace("%", "%%")
|
||||||
query = query.replace("?", "%s")
|
query = query.replace("?", "%s")
|
||||||
return query
|
return query
|
||||||
@@ -82,16 +84,30 @@ class Database(Compat):
|
|||||||
|
|
||||||
if LNBITS_DATABASE_URL:
|
if LNBITS_DATABASE_URL:
|
||||||
database_uri = LNBITS_DATABASE_URL
|
database_uri = LNBITS_DATABASE_URL
|
||||||
self.type = POSTGRES
|
|
||||||
|
if database_uri.startswith("cockroachdb://"):
|
||||||
|
self.type = COCKROACH
|
||||||
|
else:
|
||||||
|
self.type = POSTGRES
|
||||||
|
|
||||||
import psycopg2 # type: ignore
|
import psycopg2 # type: ignore
|
||||||
|
|
||||||
DEC2FLOAT = psycopg2.extensions.new_type(
|
psycopg2.extensions.register_type(
|
||||||
psycopg2.extensions.DECIMAL.values,
|
psycopg2.extensions.new_type(
|
||||||
"DEC2FLOAT",
|
psycopg2.extensions.DECIMAL.values,
|
||||||
lambda value, curs: float(value) if value is not None else None,
|
"DEC2FLOAT",
|
||||||
|
lambda value, curs: float(value) if value is not None else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
psycopg2.extensions.register_type(
|
||||||
|
psycopg2.extensions.new_type(
|
||||||
|
psycopg2.extensions.TIME.values + psycopg2.extensions.DATE.values,
|
||||||
|
"DATE2INT",
|
||||||
|
lambda value, curs: time.mktime(value.timetuple())
|
||||||
|
if value is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
psycopg2.extensions.register_type(DEC2FLOAT)
|
|
||||||
else:
|
else:
|
||||||
self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3")
|
self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3")
|
||||||
database_uri = f"sqlite:///{self.path}"
|
database_uri = f"sqlite:///{self.path}"
|
||||||
@@ -115,7 +131,7 @@ class Database(Compat):
|
|||||||
wconn = Connection(conn, txn, self.type, self.name, self.schema)
|
wconn = Connection(conn, txn, self.type, self.name, self.schema)
|
||||||
|
|
||||||
if self.schema:
|
if self.schema:
|
||||||
if self.type == POSTGRES:
|
if self.type in {POSTGRES, COCKROACH}:
|
||||||
await wconn.execute(
|
await wconn.execute(
|
||||||
f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
|
f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user