diff --git a/backend/scripts/onyx_redis.py b/backend/scripts/onyx_redis.py index c7eb7fbef..10eab4086 100644 --- a/backend/scripts/onyx_redis.py +++ b/backend/scripts/onyx_redis.py @@ -1,20 +1,30 @@ -# Tool to run helpful operations on Redis in production -# This is targeted for internal usage and may not have all the necessary parameters -# for general usage across custom deployments import argparse +import json import logging import sys import time from logging import getLogger from typing import cast +from uuid import UUID from redis import Redis +from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email +from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PASSWORD from onyx.configs.app_configs import REDIS_PORT +from onyx.configs.app_configs import REDIS_SSL +from onyx.db.engine import get_session_with_tenant +from onyx.db.users import get_user_by_email from onyx.redis.redis_pool import RedisPool +from shared_configs.configs import MULTI_TENANT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + +# Tool to run helpful operations on Redis in production +# This is targeted for internal usage and may not have all the necessary parameters +# for general usage across custom deployments # Configure the logger logging.basicConfig( @@ -29,6 +39,18 @@ SCAN_ITER_COUNT = 10000 BATCH_DEFAULT = 1000 +def get_user_id(user_email: str) -> tuple[UUID, str]: + tenant_id = ( + get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA + ) + + with get_session_with_tenant(tenant_id) as session: + user = get_user_by_email(user_email, session) + if user is None: + raise ValueError(f"User not found for email: {user_email}") + return user.id, tenant_id + + def onyx_redis( command: str, batch: int, @@ -37,13 +59,14 @@ def onyx_redis( port: int, db: int, password: str | None, + user_email: str | None = None, ) -> int: pool = RedisPool.create_pool( host=host, port=port, db=db, password=password if password else "", - ssl=True, + ssl=REDIS_SSL, ssl_cert_reqs="optional", ssl_ca_certs=None, ) @@ -72,6 +95,25 @@ def onyx_redis( return purge_by_match_and_type( "*connectorsync:vespa_syncing*", "string", batch, dry_run, r ) + elif command == "get_user_token": + if not user_email: + logger.error("You must specify --user-email with get_user_token") + return 1 + token_key = get_user_token_from_redis(r, user_email) + if token_key: + print(f"Token key for user {user_email}: {token_key}") + return 0 + else: + print(f"No token found for user {user_email}") + return 2 + elif command == "delete_user_token": + if not user_email: + logger.error("You must specify --user-email with delete_user_token") + return 1 + if delete_user_token_from_redis(r, user_email, dry_run): + return 0 + else: + return 2 else: pass @@ -134,6 +176,104 @@ def purge_by_match_and_type( return 0 +def get_user_token_from_redis(r: Redis, user_email: str) -> str | None: + """ + Scans Redis keys for a user token that matches user_email or user_id fields. + Returns the token key if found, else None. + """ + user_id, tenant_id = get_user_id(user_email) + + # Scan for keys matching the auth key prefix + auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT) + + matching_key = None + + for key in auth_keys: + key_str = key.decode("utf-8") + jwt_token = r.get(key_str) + + if not jwt_token: + continue + + try: + jwt_token_str = ( + jwt_token.decode("utf-8") + if isinstance(jwt_token, bytes) + else str(jwt_token) + ) + + if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"): + jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper + + jwt_data = json.loads(jwt_token_str) + if jwt_data.get("tenant_id") == tenant_id and str( + jwt_data.get("sub") + ) == str(user_id): + matching_key = key_str + break + except json.JSONDecodeError: + logger.error(f"Failed to decode JSON for key: {key_str}") + except Exception as e: + logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}") + + if matching_key: + return matching_key[len(REDIS_AUTH_KEY_PREFIX) :] + return None + + +def delete_user_token_from_redis( + r: Redis, user_email: str, dry_run: bool = False +) -> bool: + """ + Scans Redis keys for a user token matching user_email and deletes it if found. + Returns True if something was deleted, otherwise False. + """ + user_id, tenant_id = get_user_id(user_email) + + # Scan for keys matching the auth key prefix + auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT) + matching_key = None + + for key in auth_keys: + key_str = key.decode("utf-8") + jwt_token = r.get(key_str) + + if not jwt_token: + continue + + try: + jwt_token_str = ( + jwt_token.decode("utf-8") + if isinstance(jwt_token, bytes) + else str(jwt_token) + ) + + if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"): + jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper + + jwt_data = json.loads(jwt_token_str) + if jwt_data.get("tenant_id") == tenant_id and str( + jwt_data.get("sub") + ) == str(user_id): + matching_key = key_str + break + except json.JSONDecodeError: + logger.error(f"Failed to decode JSON for key: {key_str}") + except Exception as e: + logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}") + + if matching_key: + if dry_run: + logger.info(f"(DRY-RUN) Would delete token key: {matching_key}") + else: + r.delete(matching_key) + logger.info(f"Deleted token for user: {user_email}") + return True + else: + logger.info(f"No token found for user: {user_email}") + return False + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Onyx Redis Manager") parser.add_argument("--command", type=str, help="Operation to run", required=True) @@ -185,6 +325,13 @@ if __name__ == "__main__": required=False, ) + parser.add_argument( + "--user-email", + type=str, + help="User email for get or delete user token", + required=False, + ) + args = parser.parse_args() exitcode = onyx_redis( command=args.command, @@ -194,5 +341,6 @@ if __name__ == "__main__": port=args.port, db=args.db, password=args.password, + user_email=args.user_email, ) sys.exit(exitcode) diff --git a/backend/scripts/onyx_vespa.py b/backend/scripts/onyx_vespa.py new file mode 100644 index 000000000..dc8ff0ff8 --- /dev/null +++ b/backend/scripts/onyx_vespa.py @@ -0,0 +1,269 @@ +""" +Vespa Debugging Tool! + +Usage: + python vespa_debug_tool.py --action [options] + +Actions: + config : Print Vespa configuration + connect : Check Vespa connectivity + list_docs : List documents + search : Search documents + update : Update a document + delete : Delete a document + get_acls : Get document ACLs + +Options: + --tenant-id : Tenant ID + --connector-id : Connector ID + --n : Number of documents (default 10) + --query : Search query + --doc-id : Document ID + --fields : Fields to update (JSON) + +Example: (gets docs for a given tenant id and connector id) + python vespa_debug_tool.py --action list_docs --tenant-id my_tenant --connector-id 1 --n 5 +""" +import argparse +import json +from typing import Any +from typing import Dict +from typing import List + +from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id +from onyx.db.engine import get_session_with_tenant +from onyx.db.search_settings import get_current_search_settings +from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client +from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT +from onyx.document_index.vespa_constants import SEARCH_ENDPOINT +from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL +from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT + + +# Print Vespa configuration URLs +def print_vespa_config() -> None: + print(f"Vespa Application Endpoint: {VESPA_APPLICATION_ENDPOINT}") + print(f"Vespa App Container URL: {VESPA_APP_CONTAINER_URL}") + print(f"Vespa Search Endpoint: {SEARCH_ENDPOINT}") + print(f"Vespa Document ID Endpoint: {DOCUMENT_ID_ENDPOINT}") + + +# Check connectivity to Vespa endpoints +def check_vespa_connectivity() -> None: + endpoints = [ + f"{VESPA_APPLICATION_ENDPOINT}/ApplicationStatus", + f"{VESPA_APPLICATION_ENDPOINT}/tenant", + f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/", + f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default", + ] + + for endpoint in endpoints: + try: + with get_vespa_http_client() as client: + response = client.get(endpoint) + print(f"Successfully connected to Vespa at {endpoint}") + print(f"Status code: {response.status_code}") + print(f"Response: {response.text[:200]}...") + except Exception as e: + print(f"Failed to connect to Vespa at {endpoint}: {str(e)}") + + print("Vespa connectivity check completed.") + + +# Get info about the default Vespa application +def get_vespa_info() -> Dict[str, Any]: + url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default" + with get_vespa_http_client() as client: + response = client.get(url) + response.raise_for_status() + return response.json() + + +# Get index name for a tenant and connector pair +def get_index_name(tenant_id: str, connector_id: int) -> str: + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + cc_pair = get_connector_credential_pair_from_id(db_session, connector_id) + if not cc_pair: + raise ValueError(f"No connector found for id {connector_id}") + search_settings = get_current_search_settings(db_session) + return search_settings.index_name if search_settings else "public" + + +# Perform a Vespa query using YQL syntax +def query_vespa(yql: str) -> List[Dict[str, Any]]: + params = { + "yql": yql, + "timeout": "10s", + } + with get_vespa_http_client() as client: + response = client.get(SEARCH_ENDPOINT, params=params) + response.raise_for_status() + return response.json()["root"]["children"] + + +# Get first N documents +def get_first_n_documents(n: int = 10) -> List[Dict[str, Any]]: + yql = f"select * from sources * where true limit {n};" + return query_vespa(yql) + + +# Pretty-print a list of documents +def print_documents(documents: List[Dict[str, Any]]) -> None: + for doc in documents: + print(json.dumps(doc, indent=2)) + print("-" * 80) + + +# Get and print documents for a specific tenant and connector +def get_documents_for_tenant_connector( + tenant_id: str, connector_id: int, n: int = 10 +) -> None: + get_index_name(tenant_id, connector_id) + documents = get_first_n_documents(n) + print(f"First {n} documents for tenant {tenant_id}, connector {connector_id}:") + print_documents(documents) + + +# Search documents for a specific tenant and connector +def search_documents( + tenant_id: str, connector_id: int, query: str, n: int = 10 +) -> None: + index_name = get_index_name(tenant_id, connector_id) + yql = f"select * from sources {index_name} where userInput(@query) limit {n};" + documents = query_vespa(yql) + print(f"Search results for query '{query}':") + print_documents(documents) + + +# Update a specific document +def update_document( + tenant_id: str, connector_id: int, doc_id: str, fields: Dict[str, Any] +) -> None: + index_name = get_index_name(tenant_id, connector_id) + url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}" + update_request = {"fields": {k: {"assign": v} for k, v in fields.items()}} + + with get_vespa_http_client() as client: + response = client.put(url, json=update_request) + response.raise_for_status() + print(f"Document {doc_id} updated successfully") + + +# Delete a specific document +def delete_document(tenant_id: str, connector_id: int, doc_id: str) -> None: + index_name = get_index_name(tenant_id, connector_id) + url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}" + + with get_vespa_http_client() as client: + response = client.delete(url) + response.raise_for_status() + print(f"Document {doc_id} deleted successfully") + + +# List documents from any source +def list_documents(n: int = 10) -> None: + yql = f"select * from sources * where true limit {n};" + url = f"{VESPA_APP_CONTAINER_URL}/search/" + params = { + "yql": yql, + "timeout": "10s", + } + try: + with get_vespa_http_client() as client: + response = client.get(url, params=params) + response.raise_for_status() + documents = response.json()["root"]["children"] + print(f"First {n} documents:") + print_documents(documents) + except Exception as e: + print(f"Failed to list documents: {str(e)}") + + +# Get and print ACLs for documents of a specific tenant and connector +def get_document_acls(tenant_id: str, connector_id: int, n: int = 10) -> None: + index_name = get_index_name(tenant_id, connector_id) + yql = f"select documentid, access_control_list from sources {index_name} where true limit {n};" + documents = query_vespa(yql) + print(f"ACLs for {n} documents from tenant {tenant_id}, connector {connector_id}:") + for doc in documents: + print(f"Document ID: {doc['fields']['documentid']}") + print( + f"ACL: {json.dumps(doc['fields'].get('access_control_list', {}), indent=2)}" + ) + print("-" * 80) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Vespa debugging tool") + parser.add_argument( + "--action", + choices=[ + "config", + "connect", + "list_docs", + "search", + "update", + "delete", + "get_acls", + ], + required=True, + help="Action to perform", + ) + parser.add_argument( + "--tenant-id", help="Tenant ID (for update, delete, and get_acls actions)" + ) + parser.add_argument( + "--connector-id", + type=int, + help="Connector ID (for update, delete, and get_acls actions)", + ) + parser.add_argument( + "--n", + type=int, + default=10, + help="Number of documents to retrieve (for list_docs, search, update, and get_acls actions)", + ) + parser.add_argument("--query", help="Search query (for search action)") + parser.add_argument("--doc-id", help="Document ID (for update and delete actions)") + parser.add_argument( + "--fields", help="Fields to update, in JSON format (for update action)" + ) + + args = parser.parse_args() + + if args.action == "config": + print_vespa_config() + elif args.action == "connect": + check_vespa_connectivity() + elif args.action == "list_docs": + # If tenant_id and connector_id are provided, list docs for that tenant/connector. + # Otherwise, list documents from any source. + if args.tenant_id and args.connector_id: + get_documents_for_tenant_connector( + args.tenant_id, args.connector_id, args.n + ) + else: + list_documents(args.n) + elif args.action == "search": + if not args.query: + parser.error("--query is required for search action") + search_documents(args.tenant_id, args.connector_id, args.query, args.n) + elif args.action == "update": + if not args.doc_id or not args.fields: + parser.error("--doc-id and --fields are required for update action") + fields = json.loads(args.fields) + update_document(args.tenant_id, args.connector_id, args.doc_id, fields) + elif args.action == "delete": + if not args.doc_id: + parser.error("--doc-id is required for delete action") + delete_document(args.tenant_id, args.connector_id, args.doc_id) + elif args.action == "get_acls": + if not args.tenant_id or args.connector_id is None: + parser.error( + "--tenant-id and --connector-id are required for get_acls action" + ) + get_document_acls(args.tenant_id, args.connector_id, args.n) + + +if __name__ == "__main__": + main()