support cockroachdb.

This commit is contained in:
fiatjaf
2021-07-02 18:32:58 -03:00
parent 2466cd59db
commit ffadce02b0
3 changed files with 33 additions and 15 deletions

View File

@@ -5,7 +5,7 @@ import importlib
import re import re
import os import os
from .db import SQLITE, POSTGRES from .db import SQLITE, POSTGRES, COCKROACH
from .core import db as core_db, migrations as core_migrations from .core import db as core_db, migrations as core_migrations
from .helpers import ( from .helpers import (
get_valid_extensions, get_valid_extensions,
@@ -83,7 +83,7 @@ async def migrate_databases():
exists = await conn.fetchone( exists = await conn.fetchone(
"SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'"
) )
elif conn.type == POSTGRES: elif conn.type in {POSTGRES, COCKROACH}:
exists = await conn.fetchone( exists = await conn.fetchone(
"SELECT * FROM information_schema.tables WHERE table_name = 'dbversions'" "SELECT * FROM information_schema.tables WHERE table_name = 'dbversions'"
) )

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Dict, Any
from urllib.parse import urlparse from urllib.parse import urlparse
from lnbits import bolt11 from lnbits import bolt11
from lnbits.db import Connection, POSTGRES from lnbits.db import Connection, POSTGRES, COCKROACH
from lnbits.settings import DEFAULT_WALLET_NAME from lnbits.settings import DEFAULT_WALLET_NAME
from . import db from . import db
@@ -221,6 +221,8 @@ async def get_payments(
if since != None: if since != None:
if db.type == POSTGRES: if db.type == POSTGRES:
clause.append("time > to_timestamp(?)") clause.append("time > to_timestamp(?)")
elif db.type == COCKROACH:
clause.append("time > cast(? AS timestamp)")
else: else:
clause.append("time > ?") clause.append("time > ?")
args.append(since) args.append(since)

View File

@@ -1,5 +1,6 @@
import os import os
import trio import trio
import time
from typing import Optional from typing import Optional
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from sqlalchemy import create_engine # type: ignore from sqlalchemy import create_engine # type: ignore
@@ -9,6 +10,7 @@ from sqlalchemy_aio.base import AsyncConnection # type: ignore
from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL
POSTGRES = "POSTGRES" POSTGRES = "POSTGRES"
COCKROACH = "COCKROACH"
SQLITE = "SQLITE" SQLITE = "SQLITE"
@@ -17,7 +19,7 @@ class Compat:
schema: Optional[str] = "<inherited>" schema: Optional[str] = "<inherited>"
def interval_seconds(self, seconds: int) -> str: def interval_seconds(self, seconds: int) -> str:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
return f"interval '{seconds} seconds'" return f"interval '{seconds} seconds'"
elif self.type == SQLITE: elif self.type == SQLITE:
return f"{seconds}" return f"{seconds}"
@@ -25,7 +27,7 @@ class Compat:
@property @property
def timestamp_now(self) -> str: def timestamp_now(self) -> str:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
return "now()" return "now()"
elif self.type == SQLITE: elif self.type == SQLITE:
return "(strftime('%s', 'now'))" return "(strftime('%s', 'now'))"
@@ -33,7 +35,7 @@ class Compat:
@property @property
def serial_primary_key(self) -> str: def serial_primary_key(self) -> str:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
return "SERIAL PRIMARY KEY" return "SERIAL PRIMARY KEY"
elif self.type == SQLITE: elif self.type == SQLITE:
return "INTEGER PRIMARY KEY AUTOINCREMENT" return "INTEGER PRIMARY KEY AUTOINCREMENT"
@@ -41,7 +43,7 @@ class Compat:
@property @property
def references_schema(self) -> str: def references_schema(self) -> str:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
return f"{self.schema}." return f"{self.schema}."
elif self.type == SQLITE: elif self.type == SQLITE:
return "" return ""
@@ -57,7 +59,7 @@ class Connection(Compat):
self.schema = schema self.schema = schema
def rewrite_query(self, query) -> str: def rewrite_query(self, query) -> str:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
query = query.replace("%", "%%") query = query.replace("%", "%%")
query = query.replace("?", "%s") query = query.replace("?", "%s")
return query return query
@@ -82,16 +84,30 @@ class Database(Compat):
if LNBITS_DATABASE_URL: if LNBITS_DATABASE_URL:
database_uri = LNBITS_DATABASE_URL database_uri = LNBITS_DATABASE_URL
self.type = POSTGRES
if database_uri.startswith("cockroachdb://"):
self.type = COCKROACH
else:
self.type = POSTGRES
import psycopg2 # type: ignore import psycopg2 # type: ignore
DEC2FLOAT = psycopg2.extensions.new_type( psycopg2.extensions.register_type(
psycopg2.extensions.DECIMAL.values, psycopg2.extensions.new_type(
"DEC2FLOAT", psycopg2.extensions.DECIMAL.values,
lambda value, curs: float(value) if value is not None else None, "DEC2FLOAT",
lambda value, curs: float(value) if value is not None else None,
)
)
psycopg2.extensions.register_type(
psycopg2.extensions.new_type(
psycopg2.extensions.TIME.values + psycopg2.extensions.DATE.values,
"DATE2INT",
lambda value, curs: time.mktime(value.timetuple())
if value is not None
else None,
)
) )
psycopg2.extensions.register_type(DEC2FLOAT)
else: else:
self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3") self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3")
database_uri = f"sqlite:///{self.path}" database_uri = f"sqlite:///{self.path}"
@@ -115,7 +131,7 @@ class Database(Compat):
wconn = Connection(conn, txn, self.type, self.name, self.schema) wconn = Connection(conn, txn, self.type, self.name, self.schema)
if self.schema: if self.schema:
if self.type == POSTGRES: if self.type in {POSTGRES, COCKROACH}:
await wconn.execute( await wconn.execute(
f"CREATE SCHEMA IF NOT EXISTS {self.schema}" f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
) )