From ffadce02b08a9f31c4fbe014e858f381a515c5a5 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Fri, 2 Jul 2021 18:32:58 -0300 Subject: [PATCH] support cockroachdb. --- lnbits/commands.py | 4 ++-- lnbits/core/crud.py | 4 +++- lnbits/db.py | 40 ++++++++++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/lnbits/commands.py b/lnbits/commands.py index 02fac66db..021d26dcd 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -5,7 +5,7 @@ import importlib import re import os -from .db import SQLITE, POSTGRES +from .db import SQLITE, POSTGRES, COCKROACH from .core import db as core_db, migrations as core_migrations from .helpers import ( get_valid_extensions, @@ -83,7 +83,7 @@ async def migrate_databases(): exists = await conn.fetchone( "SELECT * FROM sqlite_master WHERE type='table' AND name='dbversions'" ) - elif conn.type == POSTGRES: + elif conn.type in {POSTGRES, COCKROACH}: exists = await conn.fetchone( "SELECT * FROM information_schema.tables WHERE table_name = 'dbversions'" ) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index ab9026e3a..8135dc883 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -5,7 +5,7 @@ from typing import List, Optional, Dict, Any from urllib.parse import urlparse from lnbits import bolt11 -from lnbits.db import Connection, POSTGRES +from lnbits.db import Connection, POSTGRES, COCKROACH from lnbits.settings import DEFAULT_WALLET_NAME from . import db @@ -221,6 +221,8 @@ async def get_payments( if since != None: if db.type == POSTGRES: clause.append("time > to_timestamp(?)") + elif db.type == COCKROACH: + clause.append("time > cast(? AS timestamp)") else: clause.append("time > ?") args.append(since) diff --git a/lnbits/db.py b/lnbits/db.py index 11548ac53..3eb433388 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -1,5 +1,6 @@ import os import trio +import time from typing import Optional from contextlib import asynccontextmanager from sqlalchemy import create_engine # type: ignore @@ -9,6 +10,7 @@ from sqlalchemy_aio.base import AsyncConnection # type: ignore from .settings import LNBITS_DATA_FOLDER, LNBITS_DATABASE_URL POSTGRES = "POSTGRES" +COCKROACH = "COCKROACH" SQLITE = "SQLITE" @@ -17,7 +19,7 @@ class Compat: schema: Optional[str] = "" def interval_seconds(self, seconds: int) -> str: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: return f"interval '{seconds} seconds'" elif self.type == SQLITE: return f"{seconds}" @@ -25,7 +27,7 @@ class Compat: @property def timestamp_now(self) -> str: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: return "now()" elif self.type == SQLITE: return "(strftime('%s', 'now'))" @@ -33,7 +35,7 @@ class Compat: @property def serial_primary_key(self) -> str: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: return "SERIAL PRIMARY KEY" elif self.type == SQLITE: return "INTEGER PRIMARY KEY AUTOINCREMENT" @@ -41,7 +43,7 @@ class Compat: @property def references_schema(self) -> str: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: return f"{self.schema}." elif self.type == SQLITE: return "" @@ -57,7 +59,7 @@ class Connection(Compat): self.schema = schema def rewrite_query(self, query) -> str: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: query = query.replace("%", "%%") query = query.replace("?", "%s") return query @@ -82,16 +84,30 @@ class Database(Compat): if LNBITS_DATABASE_URL: database_uri = LNBITS_DATABASE_URL - self.type = POSTGRES + + if database_uri.startswith("cockroachdb://"): + self.type = COCKROACH + else: + self.type = POSTGRES import psycopg2 # type: ignore - DEC2FLOAT = psycopg2.extensions.new_type( - psycopg2.extensions.DECIMAL.values, - "DEC2FLOAT", - lambda value, curs: float(value) if value is not None else None, + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.extensions.DECIMAL.values, + "DEC2FLOAT", + lambda value, curs: float(value) if value is not None else None, + ) + ) + psycopg2.extensions.register_type( + psycopg2.extensions.new_type( + psycopg2.extensions.TIME.values + psycopg2.extensions.DATE.values, + "DATE2INT", + lambda value, curs: time.mktime(value.timetuple()) + if value is not None + else None, + ) ) - psycopg2.extensions.register_type(DEC2FLOAT) else: self.path = os.path.join(LNBITS_DATA_FOLDER, f"{self.name}.sqlite3") database_uri = f"sqlite:///{self.path}" @@ -115,7 +131,7 @@ class Database(Compat): wconn = Connection(conn, txn, self.type, self.name, self.schema) if self.schema: - if self.type == POSTGRES: + if self.type in {POSTGRES, COCKROACH}: await wconn.execute( f"CREATE SCHEMA IF NOT EXISTS {self.schema}" )