From 93965bc5b646bd5d2a79d93ea3f90fa56bdae9f1 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Fri, 24 May 2024 00:23:32 +0300 Subject: [PATCH] [test] `webpush_api` endpoints (#2534) * test: webpush_api endpoints * fix: SQL quote for `user` --- lnbits/core/crud.py | 12 +++-- lnbits/core/views/webpush_api.py | 3 +- tests/api/test_webpush_api.py | 89 ++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 tests/api/test_webpush_api.py diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 6e3df6293..971994888 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -1287,7 +1287,7 @@ async def create_webpush_subscription( ) -> WebPushSubscription: await db.execute( """ - INSERT INTO webpush_subscriptions (endpoint, user, data, host) + INSERT INTO webpush_subscriptions (endpoint, "user", data, host) VALUES (?, ?, ?, ?) """, ( @@ -1302,17 +1302,19 @@ async def create_webpush_subscription( return subscription -async def delete_webpush_subscription(endpoint: str, user: str) -> None: - await db.execute( +async def delete_webpush_subscription(endpoint: str, user: str) -> int: + resp = await db.execute( """DELETE FROM webpush_subscriptions WHERE endpoint = ? AND "user" = ?""", ( endpoint, user, ), ) + return resp.rowcount -async def delete_webpush_subscriptions(endpoint: str) -> None: - await db.execute( +async def delete_webpush_subscriptions(endpoint: str) -> int: + resp = await db.execute( "DELETE FROM webpush_subscriptions WHERE endpoint = ?", (endpoint,) ) + return resp.rowcount diff --git a/lnbits/core/views/webpush_api.py b/lnbits/core/views/webpush_api.py index 26e506fc9..66a482312 100644 --- a/lnbits/core/views/webpush_api.py +++ b/lnbits/core/views/webpush_api.py @@ -67,7 +67,8 @@ async def api_delete_webpush_subscription( endpoint = unquote( base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8") ) - await delete_webpush_subscription(endpoint, wallet.wallet.user) + count = await delete_webpush_subscription(endpoint, wallet.wallet.user) + return {"count": count} except Exception as exc: logger.debug(exc) raise HTTPException( diff --git a/tests/api/test_webpush_api.py b/tests/api/test_webpush_api.py new file mode 100644 index 000000000..36cc7c750 --- /dev/null +++ b/tests/api/test_webpush_api.py @@ -0,0 +1,89 @@ +from http import HTTPStatus + +import pytest + + +@pytest.mark.asyncio +async def test_create___bad_body(client, adminkey_headers_from): + response = await client.post( + "/api/v1/webpush", + headers=adminkey_headers_from, + json={"subscription": "bad_json"}, + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_create___missing_fields(client, adminkey_headers_from): + response = await client.post( + "/api/v1/webpush", + headers=adminkey_headers_from, + json={"subscription": """{"a": "x"}"""}, + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_create___bad_access_key(client, inkey_headers_from): + response = await client.post( + "/api/v1/webpush", + headers=inkey_headers_from, + json={"subscription": """{"a": "x"}"""}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_delete__bad_endpoint_format(client, adminkey_headers_from): + response = await client.delete( + "/api/v1/webpush", + params={"endpoint": "https://this.should.be.base64.com"}, + headers=adminkey_headers_from, + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_delete__no_endpoint_param(client, adminkey_headers_from): + response = await client.delete( + "/api/v1/webpush", + headers=adminkey_headers_from, + ) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_delete__no_endpoint_found(client, adminkey_headers_from): + response = await client.delete( + "/api/v1/webpush", + params={"endpoint": "aHR0cHM6Ly9kZW1vLmxuYml0cy5jb20="}, + headers=adminkey_headers_from, + ) + assert response.status_code == HTTPStatus.OK + assert response.json()["count"] == 0 + + +@pytest.mark.asyncio +async def test_delete__bad_access_key(client, inkey_headers_from): + response = await client.delete( + "/api/v1/webpush", + headers=inkey_headers_from, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_create_and_delete(client, adminkey_headers_from): + response = await client.post( + "/api/v1/webpush", + headers=adminkey_headers_from, + json={"subscription": """{"endpoint": "https://demo.lnbits.com"}"""}, + ) + assert response.status_code == HTTPStatus.CREATED + response = await client.delete( + "/api/v1/webpush", + params={"endpoint": "aHR0cHM6Ly9kZW1vLmxuYml0cy5jb20="}, + headers=adminkey_headers_from, + ) + assert response.status_code == HTTPStatus.OK + assert response.json()["count"] == 1