mirror of
https://github.com/lnbits/lnbits.git
synced 2025-10-11 04:52:34 +02:00
refactor: improve database migrations
This commit is contained in:
28
lnbits/db.py
28
lnbits/db.py
@@ -1,8 +1,7 @@
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
from .helpers import ExtensionManager
|
||||
from .settings import LNBITS_PATH, LNBITS_DATA_FOLDER
|
||||
from .settings import LNBITS_DATA_FOLDER
|
||||
|
||||
|
||||
class Database:
|
||||
@@ -19,16 +18,16 @@ class Database:
|
||||
self.cursor.close()
|
||||
self.connection.close()
|
||||
|
||||
def fetchall(self, query: str, values: tuple) -> list:
|
||||
def fetchall(self, query: str, values: tuple = ()) -> list:
|
||||
"""Given a query, return cursor.fetchall() rows."""
|
||||
self.cursor.execute(query, values)
|
||||
return self.cursor.fetchall()
|
||||
|
||||
def fetchone(self, query: str, values: tuple):
|
||||
def fetchone(self, query: str, values: tuple = ()):
|
||||
self.cursor.execute(query, values)
|
||||
return self.cursor.fetchone()
|
||||
|
||||
def execute(self, query: str, values: tuple) -> None:
|
||||
def execute(self, query: str, values: tuple = ()) -> None:
|
||||
"""Given a query, cursor.execute() it."""
|
||||
self.cursor.execute(query, values)
|
||||
self.connection.commit()
|
||||
@@ -41,22 +40,3 @@ def open_db(db_name: str = "database") -> Database:
|
||||
|
||||
def open_ext_db(extension_name: str) -> Database:
|
||||
return open_db(f"ext_{extension_name}")
|
||||
|
||||
|
||||
def init_databases() -> None:
|
||||
"""Creates the necessary databases if they don't exist already."""
|
||||
"""TODO: see how we can deal with migrations."""
|
||||
|
||||
schemas = [
|
||||
("database", os.path.join(LNBITS_PATH, "core", "schema.sql")),
|
||||
]
|
||||
|
||||
for extension in ExtensionManager().extensions:
|
||||
extension_path = os.path.join(LNBITS_PATH, "extensions", extension.code)
|
||||
schemas.append((f"ext_{extension.code}", os.path.join(extension_path, "schema.sql")))
|
||||
|
||||
for schema in [s for s in schemas if os.path.exists(s[1])]:
|
||||
with open_db(schema[0]) as db:
|
||||
with open(schema[1]) as schemafile:
|
||||
for stmt in schemafile.read().split(";\n\n"):
|
||||
db.execute(stmt, [])
|
||||
|
Reference in New Issue
Block a user