diff --git a/.github/workflows/migrations.yml b/.github/workflows/migrations.yml index 45de97277..08557bc1e 100644 --- a/.github/workflows/migrations.yml +++ b/.github/workflows/migrations.yml @@ -38,12 +38,14 @@ jobs: ./venv/bin/python -m pip install --upgrade pip ./venv/bin/pip install -r requirements.txt ./venv/bin/pip install pytest pytest-asyncio pytest-cov requests mock + sudo apt install unzip - name: Run migrations run: | rm -rf ./data mkdir -p ./data export LNBITS_DATA_FOLDER="./data" + unzip tests/data/mock_data.zip -d ./data timeout 5s ./venv/bin/uvicorn lnbits.__main__:app --host 0.0.0.0 --port 5001 || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi export LNBITS_DATABASE_URL="postgres://postgres:postgres@0.0.0.0:5432/postgres" timeout 5s ./venv/bin/uvicorn lnbits.__main__:app --host 0.0.0.0 --port 5001 || code=$?; if [[ $code -ne 124 && $code -ne 0 ]]; then exit $code; fi - ./venv/bin/python tools/conv.py --dont-ignore-missing + ./venv/bin/python tools/conv.py diff --git a/.github/workflows/regtest.yml b/.github/workflows/regtest.yml index 7883cf199..f26e6c386 100644 --- a/.github/workflows/regtest.yml +++ b/.github/workflows/regtest.yml @@ -19,17 +19,11 @@ jobs: docker build -t lnbits-legend . git clone https://github.com/lnbits/legend-regtest-enviroment.git docker cd docker - source docker-scripts.sh - lnbits-regtest-start - echo "sleeping 60 seconds" - sleep 60 - echo "continue" - lnbits-regtest-init - bitcoin-cli-sim -generate 1 - lncli-sim 1 listpeers + chmod +x ./tests + ./tests sudo chmod -R a+rwx . - name: Install dependencies - env: + env: VIRTUAL_ENV: ./venv PATH: ${{ env.VIRTUAL_ENV }}/bin:${{ env.PATH }} run: | @@ -37,7 +31,7 @@ jobs: ./venv/bin/python -m pip install --upgrade pip ./venv/bin/pip install -r requirements.txt ./venv/bin/pip install pylightning - ./venv/bin/pip install pytest pytest-asyncio pytest-cov requests mock + ./venv/bin/pip install pytest pytest-asyncio pytest-cov requests mock - name: Run tests env: PYTHONUNBUFFERED: 1 @@ -66,17 +60,11 @@ jobs: docker build -t lnbits-legend . git clone https://github.com/lnbits/legend-regtest-enviroment.git docker cd docker - source docker-scripts.sh - lnbits-regtest-start - echo "sleeping 60 seconds" - sleep 60 - echo "continue" - lnbits-regtest-init - bitcoin-cli-sim -generate 1 - lncli-sim 1 listpeers + chmod +x ./tests + ./tests sudo chmod -R a+rwx . - name: Install dependencies - env: + env: VIRTUAL_ENV: ./venv PATH: ${{ env.VIRTUAL_ENV }}/bin:${{ env.PATH }} run: | @@ -84,7 +72,7 @@ jobs: ./venv/bin/python -m pip install --upgrade pip ./venv/bin/pip install -r requirements.txt ./venv/bin/pip install pylightning - ./venv/bin/pip install pytest pytest-asyncio pytest-cov requests mock + ./venv/bin/pip install pytest pytest-asyncio pytest-cov requests mock - name: Run tests env: PYTHONUNBUFFERED: 1 @@ -94,4 +82,4 @@ jobs: CLIGHTNING_RPC: docker/data/clightning-1/regtest/lightning-rpc run: | sudo chmod -R a+rwx . && rm -rf ./data && mkdir -p ./data - make test-real-wallet \ No newline at end of file + make test-real-wallet diff --git a/.gitignore b/.gitignore index b70f620d5..c3b3225fb 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,7 @@ __pycache__ .webassets-cache htmlcov test-reports -tests/data +tests/data/*.sqlite3 *.swo *.swp diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 86f7ecff9..44666ce16 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -23,6 +23,7 @@ from lnbits.settings import ( SERVICE_FEE, ) +from ...helpers import get_valid_extensions from ..crud import ( create_account, create_wallet, @@ -66,6 +67,14 @@ async def extensions( HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension." ) + # check if extension exists + if extension_to_enable or extension_to_disable: + ext = extension_to_enable or extension_to_disable + if ext not in [e.code for e in get_valid_extensions()]: + raise HTTPException( + HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist." + ) + if extension_to_enable: logger.info(f"Enabling extension: {extension_to_enable} for user {user.id}") await update_user_extension( diff --git a/lnbits/extensions/tpos/crud.py b/lnbits/extensions/tpos/crud.py index 8f071d8cc..94e2c0068 100644 --- a/lnbits/extensions/tpos/crud.py +++ b/lnbits/extensions/tpos/crud.py @@ -30,7 +30,7 @@ async def create_tpos(wallet_id: str, data: CreateTposData) -> TPoS: async def get_tpos(tpos_id: str) -> Optional[TPoS]: row = await db.fetchone("SELECT * FROM tpos.tposs WHERE id = ?", (tpos_id,)) - return TPoS.from_row(row) if row else None + return TPoS(**row) if row else None async def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: @@ -42,7 +42,7 @@ async def get_tposs(wallet_ids: Union[str, List[str]]) -> List[TPoS]: f"SELECT * FROM tpos.tposs WHERE wallet IN ({q})", (*wallet_ids,) ) - return [TPoS.from_row(row) for row in rows] + return [TPoS(**row) for row in rows] async def delete_tpos(tpos_id: str) -> None: diff --git a/lnbits/extensions/tpos/models.py b/lnbits/extensions/tpos/models.py index 6a2ff1d2c..36bca79be 100644 --- a/lnbits/extensions/tpos/models.py +++ b/lnbits/extensions/tpos/models.py @@ -1,13 +1,15 @@ from sqlite3 import Row +from typing import Optional +from fastapi import Query from pydantic import BaseModel class CreateTposData(BaseModel): name: str currency: str - tip_options: str - tip_wallet: str + tip_options: str = Query(None) + tip_wallet: str = Query(None) class TPoS(BaseModel): @@ -15,8 +17,8 @@ class TPoS(BaseModel): wallet: str name: str currency: str - tip_options: str - tip_wallet: str + tip_options: Optional[str] + tip_wallet: Optional[str] @classmethod def from_row(cls, row: Row) -> "TPoS": diff --git a/lnbits/extensions/tpos/tasks.py b/lnbits/extensions/tpos/tasks.py index 01c11428d..af9663cc9 100644 --- a/lnbits/extensions/tpos/tasks.py +++ b/lnbits/extensions/tpos/tasks.py @@ -26,7 +26,6 @@ async def on_invoice_paid(payment: Payment) -> None: # now we make some special internal transfers (from no one to the receiver) tpos = await get_tpos(payment.extra.get("tposId")) - tipAmount = payment.extra.get("tipAmount") if tipAmount is None: @@ -34,6 +33,7 @@ async def on_invoice_paid(payment: Payment) -> None: return tipAmount = tipAmount * 1000 + amount = payment.amount - tipAmount # mark the original payment with one extra key, "splitted" # (this prevents us from doing this process again and it's informative) @@ -41,13 +41,13 @@ async def on_invoice_paid(payment: Payment) -> None: await core_db.execute( """ UPDATE apipayments - SET extra = ?, amount = amount - ? + SET extra = ?, amount = ? WHERE hash = ? AND checking_id NOT LIKE 'internal_%' """, ( json.dumps(dict(**payment.extra, tipSplitted=True)), - tipAmount, + amount, payment.payment_hash, ), ) @@ -60,7 +60,7 @@ async def on_invoice_paid(payment: Payment) -> None: payment_request="", payment_hash=payment.payment_hash, amount=tipAmount, - memo=payment.memo, + memo=f"Tip for {payment.memo}", pending=False, extra={"tipSplitted": True}, ) diff --git a/lnbits/extensions/tpos/templates/tpos/index.html b/lnbits/extensions/tpos/templates/tpos/index.html index 76f330007..edbb2aa87 100644 --- a/lnbits/extensions/tpos/templates/tpos/index.html +++ b/lnbits/extensions/tpos/templates/tpos/index.html @@ -54,8 +54,8 @@ > - {{ (col.name == 'tip_options' ? JSON.parse(col.value).join(", ") - : col.value) }} + {{ (col.name == 'tip_options' && col.value ? + JSON.parse(col.value).join(", ") : col.value) }}

{% raw %}{{ famount }}{% endraw %}

- {% raw %}{{ fsat }}{% endraw %} sat + {% raw %}{{ fsat }} + sat + ( + {{ tipAmountSat }} tip) + {% endraw %}
@@ -310,7 +313,6 @@ return Math.ceil((this.tipAmount / this.exchangeRate) * 100000000) }, fsat: function () { - console.log('sat', this.sat, LNbits.utils.formatSat(this.sat)) return LNbits.utils.formatSat(this.sat) } }, @@ -362,7 +364,6 @@ showInvoice: function () { var self = this var dialog = this.invoiceDialog - console.log(this.sat, this.tposId) axios .post('/tpos/api/v1/tposs/' + this.tposId + '/invoices', null, { params: { diff --git a/lnbits/extensions/tpos/views_api.py b/lnbits/extensions/tpos/views_api.py index 9567f98a6..9609956ec 100644 --- a/lnbits/extensions/tpos/views_api.py +++ b/lnbits/extensions/tpos/views_api.py @@ -17,7 +17,7 @@ from .models import CreateTposData @tpos_ext.get("/api/v1/tposs", status_code=HTTPStatus.OK) async def api_tposs( - all_wallets: bool = Query(None), wallet: WalletTypeInfo = Depends(get_key_type) + all_wallets: bool = Query(False), wallet: WalletTypeInfo = Depends(get_key_type) ): wallet_ids = [wallet.wallet.id] if all_wallets: @@ -63,6 +63,9 @@ async def api_tpos_create_invoice( status_code=HTTPStatus.NOT_FOUND, detail="TPoS does not exist." ) + if tipAmount: + amount += tipAmount + try: payment_hash, payment_request = await create_invoice( wallet_id=tpos.wallet, diff --git a/lnbits/extensions/withdraw/README.md b/lnbits/extensions/withdraw/README.md index 0e5939fdb..7bf7c232c 100644 --- a/lnbits/extensions/withdraw/README.md +++ b/lnbits/extensions/withdraw/README.md @@ -26,6 +26,8 @@ LNBits Quick Vouchers allows you to easily create a batch of LNURLw's QR codes t - on details you can print the vouchers\ ![printable vouchers](https://i.imgur.com/2xLHbob.jpg) - every printed LNURLw QR code is unique, it can only be used once +3. Bonus: you can use an LNbits themed voucher, or use a custom one. There's a _template.svg_ file in `static/images` folder if you want to create your own.\ + ![voucher](https://i.imgur.com/qyQoHi3.jpg) #### Advanced diff --git a/lnbits/extensions/withdraw/crud.py b/lnbits/extensions/withdraw/crud.py index ab35fafac..9868b0570 100644 --- a/lnbits/extensions/withdraw/crud.py +++ b/lnbits/extensions/withdraw/crud.py @@ -26,9 +26,10 @@ async def create_withdraw_link( k1, open_time, usescsv, - webhook_url + webhook_url, + custom_url ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( link_id, @@ -44,6 +45,7 @@ async def create_withdraw_link( int(datetime.now().timestamp()) + data.wait_time, usescsv, data.webhook_url, + data.custom_url, ), ) link = await get_withdraw_link(link_id, 0) diff --git a/lnbits/extensions/withdraw/migrations.py b/lnbits/extensions/withdraw/migrations.py index 83f3fc242..5484277a2 100644 --- a/lnbits/extensions/withdraw/migrations.py +++ b/lnbits/extensions/withdraw/migrations.py @@ -115,3 +115,10 @@ async def m004_webhook_url(db): Adds webhook_url """ await db.execute("ALTER TABLE withdraw.withdraw_link ADD COLUMN webhook_url TEXT;") + + +async def m005_add_custom_print_design(db): + """ + Adds custom print design + """ + await db.execute("ALTER TABLE withdraw.withdraw_link ADD COLUMN custom_url TEXT;") diff --git a/lnbits/extensions/withdraw/models.py b/lnbits/extensions/withdraw/models.py index c3ca7c459..2672537fa 100644 --- a/lnbits/extensions/withdraw/models.py +++ b/lnbits/extensions/withdraw/models.py @@ -16,6 +16,7 @@ class CreateWithdrawData(BaseModel): wait_time: int = Query(..., ge=1) is_unique: bool webhook_url: str = Query(None) + custom_url: str = Query(None) class WithdrawLink(BaseModel): @@ -34,6 +35,7 @@ class WithdrawLink(BaseModel): usescsv: str = Query(None) number: int = Query(0) webhook_url: str = Query(None) + custom_url: str = Query(None) @property def is_spent(self) -> bool: diff --git a/lnbits/extensions/withdraw/static/js/index.js b/lnbits/extensions/withdraw/static/js/index.js index 3f484debf..1982d6845 100644 --- a/lnbits/extensions/withdraw/static/js/index.js +++ b/lnbits/extensions/withdraw/static/js/index.js @@ -20,9 +20,12 @@ var mapWithdrawLink = function (obj) { obj.uses_left = obj.uses - obj.used obj.print_url = [locationPath, 'print/', obj.id].join('') obj.withdraw_url = [locationPath, obj.id].join('') + obj._data.use_custom = Boolean(obj.custom_url) return obj } +const CUSTOM_URL = '/static/images/default_voucher.png' + new Vue({ el: '#vue', mixins: [windowMixin], @@ -59,13 +62,15 @@ new Vue({ secondMultiplier: 'seconds', secondMultiplierOptions: ['seconds', 'minutes', 'hours'], data: { - is_unique: false + is_unique: false, + use_custom: false } }, simpleformDialog: { show: false, data: { is_unique: true, + use_custom: true, title: 'Vouchers', min_withdrawable: 0, wait_time: 1 @@ -106,12 +111,14 @@ new Vue({ }, closeFormDialog: function () { this.formDialog.data = { - is_unique: false + is_unique: false, + use_custom: false } }, simplecloseFormDialog: function () { this.simpleformDialog.data = { - is_unique: false + is_unique: false, + use_custom: false } }, openQrCodeDialog: function (linkId) { @@ -133,6 +140,9 @@ new Vue({ id: this.formDialog.data.wallet }) var data = _.omit(this.formDialog.data, 'wallet') + if (data.use_custom && !data?.custom_url) { + data.custom_url = CUSTOM_URL + } data.wait_time = data.wait_time * @@ -141,7 +151,6 @@ new Vue({ minutes: 60, hours: 3600 }[this.formDialog.secondMultiplier] - if (data.id) { this.updateWithdrawLink(wallet, data) } else { @@ -159,6 +168,10 @@ new Vue({ data.title = 'vouchers' data.is_unique = true + if (data.use_custom && !data?.custom_url) { + data.custom_url = '/static/images/default_voucher.png' + } + if (data.id) { this.updateWithdrawLink(wallet, data) } else { @@ -181,7 +194,8 @@ new Vue({ 'uses', 'wait_time', 'is_unique', - 'webhook_url' + 'webhook_url', + 'custom_url' ) ) .then(function (response) { diff --git a/lnbits/extensions/withdraw/templates/withdraw/index.html b/lnbits/extensions/withdraw/templates/withdraw/index.html index 99aa03b2d..9ff428a1d 100644 --- a/lnbits/extensions/withdraw/templates/withdraw/index.html +++ b/lnbits/extensions/withdraw/templates/withdraw/index.html @@ -217,6 +217,32 @@ label="Webhook URL (optional)" hint="A URL to be called whenever this link gets used." > + + + + + + + Use a custom voucher design + You can use an LNbits voucher design or a custom + one + + + + @@ -303,6 +329,32 @@ :default="1" label="Number of vouchers" > + + + + + + + Use a custom voucher design + You can use an LNbits voucher design or a custom + one + + + +
+
+ {% for page in link %} + + {% for one in page %} +
+ ... + {{ amt }} sats +
+ +
+
+ {% endfor %} +
+ {% endfor %} +
+
+{% endblock %} {% block styles %} + +{% endblock %} {% block scripts %} + +{% endblock %} diff --git a/lnbits/extensions/withdraw/views.py b/lnbits/extensions/withdraw/views.py index 1f059a4b0..97fb12717 100644 --- a/lnbits/extensions/withdraw/views.py +++ b/lnbits/extensions/withdraw/views.py @@ -99,6 +99,18 @@ async def print_qr(request: Request, link_id): page_link = list(chunks(links, 2)) linked = list(chunks(page_link, 5)) + if link.custom_url: + return withdraw_renderer().TemplateResponse( + "withdraw/print_qr_custom.html", + { + "request": request, + "link": page_link, + "unique": True, + "custom_url": link.custom_url, + "amt": link.max_withdrawable, + }, + ) + return withdraw_renderer().TemplateResponse( "withdraw/print_qr.html", {"request": request, "link": linked, "unique": True} ) diff --git a/lnbits/static/images/default_voucher.png b/lnbits/static/images/default_voucher.png new file mode 100644 index 000000000..8462b285b Binary files /dev/null and b/lnbits/static/images/default_voucher.png differ diff --git a/lnbits/static/images/voucher_template.svg b/lnbits/static/images/voucher_template.svg new file mode 100644 index 000000000..4347758f2 --- /dev/null +++ b/lnbits/static/images/voucher_template.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/tests/core/views/test_generic.py b/tests/core/views/test_generic.py index 1dff6f012..6e6354d1b 100644 --- a/tests/core/views/test_generic.py +++ b/tests/core/views/test_generic.py @@ -92,8 +92,69 @@ async def test_get_wallet_with_user_and_wallet(client, to_user, to_wallet): # check GET /wallet: wrong wallet and user, expect 204 @pytest.mark.asyncio -async def test_get_wallet_with_user_and_wrong_wallet(client, to_user, to_wallet): +async def test_get_wallet_with_user_and_wrong_wallet(client, to_user): response = await client.get("wallet", params={"usr": to_user.id, "wal": "1"}) assert response.status_code == 204, ( str(response.url) + " " + str(response.status_code) ) + + +# check GET /extensions: extensions list +@pytest.mark.asyncio +async def test_get_extensions(client, to_user): + response = await client.get("extensions", params={"usr": to_user.id}) + assert response.status_code == 200, ( + str(response.url) + " " + str(response.status_code) + ) + + +# check GET /extensions: extensions list wrong user, expect 204 +@pytest.mark.asyncio +async def test_get_extensions_wrong_user(client, to_user): + response = await client.get("extensions", params={"usr": "1"}) + assert response.status_code == 204, ( + str(response.url) + " " + str(response.status_code) + ) + + +# check GET /extensions: no user given, expect code 204 no content +@pytest.mark.asyncio +async def test_get_extensions_no_user(client): + response = await client.get("extensions") + assert response.status_code == 204, ( # no content + str(response.url) + " " + str(response.status_code) + ) + + +# check GET /extensions: enable extension +@pytest.mark.asyncio +async def test_get_extensions_enable(client, to_user): + response = await client.get( + "extensions", params={"usr": to_user.id, "enable": "lnurlp"} + ) + assert response.status_code == 200, ( + str(response.url) + " " + str(response.status_code) + ) + + +# check GET /extensions: enable nonexistent extension, expect code 400 bad request +@pytest.mark.asyncio +async def test_get_extensions_enable_nonexistent_extension(client, to_user): + response = await client.get( + "extensions", params={"usr": to_user.id, "enable": "12341234"} + ) + assert response.status_code == 400, ( + str(response.url) + " " + str(response.status_code) + ) + + +# check GET /extensions: enable and disable extensions, expect code 400 bad request +@pytest.mark.asyncio +async def test_get_extensions_enable_and_disable(client, to_user): + response = await client.get( + "extensions", + params={"usr": to_user.id, "enable": "lnurlp", "disable": "lnurlp"}, + ) + assert response.status_code == 400, ( + str(response.url) + " " + str(response.status_code) + ) diff --git a/tests/data/mock_data.zip b/tests/data/mock_data.zip new file mode 100644 index 000000000..0480adc27 Binary files /dev/null and b/tests/data/mock_data.zip differ diff --git a/tools/conv.py b/tools/conv.py index 24414da15..5821f7af8 100644 --- a/tools/conv.py +++ b/tools/conv.py @@ -38,8 +38,6 @@ else: pgport = LNBITS_DATABASE_URL.split("@")[1].split(":")[1].split("/")[0] pgschema = "" -print(pgdb, pguser, pgpswd, pghost, pgport, pgschema) - def get_sqlite_cursor(sqdb) -> sqlite3: consq = sqlite3.connect(sqdb) @@ -99,8 +97,12 @@ def insert_to_pg(query, data): for d in data: try: cursor.execute(query, d) - except: - raise ValueError(f"Failed to insert {d}") + except Exception as e: + if args.ignore_errors: + print(e) + print(f"Failed to insert {d}") + else: + raise ValueError(f"Failed to insert {d}") connection.commit() cursor.close() @@ -256,9 +258,10 @@ def migrate_ext(sqlite_db_file, schema, ignore_missing=True): k1, open_time, used, - usescsv + usescsv, + webhook_url ) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); """ insert_to_pg(q, res.fetchall()) # WITHDRAW HASH CHECK @@ -316,8 +319,8 @@ def migrate_ext(sqlite_db_file, schema, ignore_missing=True): # TPOSS res = sq.execute("SELECT * FROM tposs;") q = f""" - INSERT INTO tpos.tposs (id, wallet, name, currency) - VALUES (%s, %s, %s, %s); + INSERT INTO tpos.tposs (id, wallet, name, currency, tip_wallet, tip_options) + VALUES (%s, %s, %s, %s, %s, %s); """ insert_to_pg(q, res.fetchall()) elif schema == "tipjar": @@ -512,12 +515,13 @@ def migrate_ext(sqlite_db_file, schema, ignore_missing=True): wallet, url, memo, + description, amount, time, remembers, - extra + extras ) - VALUES (%s, %s, %s, %s, %s, to_timestamp(%s), %s, %s); + VALUES (%s, %s, %s, %s, %s, %s, to_timestamp(%s), %s, %s); """ insert_to_pg(q, res.fetchall()) elif schema == "offlineshop": @@ -543,15 +547,15 @@ def migrate_ext(sqlite_db_file, schema, ignore_missing=True): # lnurldevice res = sq.execute("SELECT * FROM lnurldevices;") q = f""" - INSERT INTO lnurldevice.lnurldevices (id, key, title, wallet, currency, device, profit) - VALUES (%s, %s, %s, %s, %s, %s, %s); + INSERT INTO lnurldevice.lnurldevices (id, key, title, wallet, currency, device, profit, timestamp) + VALUES (%s, %s, %s, %s, %s, %s, %s, to_timestamp(%s)); """ insert_to_pg(q, res.fetchall()) # lnurldevice PAYMENT res = sq.execute("SELECT * FROM lnurldevicepayment;") q = f""" - INSERT INTO lnurldevice.lnurldevicepayment (id, deviceid, payhash, payload, pin, sats) - VALUES (%s, %s, %s, %s, %s, %s); + INSERT INTO lnurldevice.lnurldevicepayment (id, deviceid, payhash, payload, pin, sats, timestamp) + VALUES (%s, %s, %s, %s, %s, %s, to_timestamp(%s)); """ insert_to_pg(q, res.fetchall()) elif schema == "lnurlp": @@ -710,36 +714,69 @@ def migrate_ext(sqlite_db_file, schema, ignore_missing=True): sq.close() -parser = argparse.ArgumentParser(description="Migrate data from SQLite to PostgreSQL") +parser = argparse.ArgumentParser( + description="LNbits migration tool for migrating data from SQLite to PostgreSQL" +) parser.add_argument( - dest="sqlite_file", + dest="sqlite_path", const=True, nargs="?", - help="SQLite DB to migrate from", - default="data/database.sqlite3", + help=f"SQLite DB folder *or* single extension db file to migrate. Default: {sqfolder}", + default=sqfolder, type=str, ) parser.add_argument( - "-i", - "--dont-ignore-missing", - help="Error if migration is missing for an extension.", + "-e", + "--extensions-only", + help="Migrate only extensions", required=False, default=False, - const=True, - nargs="?", - type=bool, + action="store_true", ) + +parser.add_argument( + "-s", + "--skip-missing", + help="Error if migration is missing for an extension", + required=False, + default=False, + action="store_true", +) + +parser.add_argument( + "-i", + "--ignore-errors", + help="Don't error if migration fails", + required=False, + default=False, + action="store_true", +) + args = parser.parse_args() -print(args) +print("Selected path: ", args.sqlite_path) -check_db_versions(args.sqlite_file) -migrate_core(args.sqlite_file) +if os.path.isdir(args.sqlite_path): + file = os.path.join(args.sqlite_path, "database.sqlite3") + check_db_versions(file) + if not args.extensions_only: + print(f"Migrating: {file}") + migrate_core(file) + +if os.path.isdir(args.sqlite_path): + files = [ + os.path.join(args.sqlite_path, file) for file in os.listdir(args.sqlite_path) + ] +else: + files = [args.sqlite_path] -files = os.listdir(sqfolder) for file in files: - path = f"data/{file}" - if file.startswith("ext_"): - schema = file.replace("ext_", "").split(".")[0] - print(f"Migrating: {schema}") - migrate_ext(path, schema, ignore_missing=not args.dont_ignore_missing) + filename = os.path.basename(file) + if filename.startswith("ext_"): + schema = filename.replace("ext_", "").split(".")[0] + print(f"Migrating: {file}") + migrate_ext( + file, + schema, + ignore_missing=args.skip_missing, + )