diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py
index 41bde426e35..a6ad19bbc42 100644
--- a/backend/onyx/db/engine.py
+++ b/backend/onyx/db/engine.py
@@ -55,8 +55,12 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"
 
+# why isn't this in configs?
 USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
 
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+
 # Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None
@@ -106,10 +110,10 @@ def build_connection_string(
     port: str = POSTGRES_PORT,
     db: str = POSTGRES_DB,
     app_name: str | None = None,
-    use_iam: bool = USE_IAM_AUTH,
+    use_iam_auth: bool = USE_IAM_AUTH,
     region: str = "us-west-2",
 ) -> str:
-    if use_iam:
+    if use_iam_auth:
         base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
     else:
         base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
@@ -176,9 +180,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
-
-
 def is_valid_schema_name(name: str) -> bool:
     return SCHEMA_NAME_REGEX.match(name) is not None
 
@@ -188,43 +189,44 @@ class SqlEngine:
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME
 
-    @classmethod
-    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
-        )
+    # NOTE(rkuo) - this appears to be unused, clean it up?
+    # @classmethod
+    # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+    #     connection_string = build_connection_string(
+    #         db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
+    #     )
 
-        # Start with base kwargs that are valid for all pool types
-        final_engine_kwargs: dict[str, Any] = {}
+    #     # Start with base kwargs that are valid for all pool types
+    #     final_engine_kwargs: dict[str, Any] = {}
 
-        if POSTGRES_USE_NULL_POOL:
-            # if null pool is specified, then we need to make sure that
-            # we remove any passed in kwargs related to pool size that would
-            # cause the initialization to fail
-            final_engine_kwargs.update(engine_kwargs)
+    #     if POSTGRES_USE_NULL_POOL:
+    #         # if null pool is specified, then we need to make sure that
+    #         # we remove any passed in kwargs related to pool size that would
+    #         # cause the initialization to fail
+    #         final_engine_kwargs.update(engine_kwargs)
 
-            final_engine_kwargs["poolclass"] = pool.NullPool
-            if "pool_size" in final_engine_kwargs:
-                del final_engine_kwargs["pool_size"]
-            if "max_overflow" in final_engine_kwargs:
-                del final_engine_kwargs["max_overflow"]
-        else:
-            final_engine_kwargs["pool_size"] = 20
-            final_engine_kwargs["max_overflow"] = 5
-            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
-            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+    #         final_engine_kwargs["poolclass"] = pool.NullPool
+    #         if "pool_size" in final_engine_kwargs:
+    #             del final_engine_kwargs["pool_size"]
+    #         if "max_overflow" in final_engine_kwargs:
+    #             del final_engine_kwargs["max_overflow"]
+    #     else:
+    #         final_engine_kwargs["pool_size"] = 20
+    #         final_engine_kwargs["max_overflow"] = 5
+    #         final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+    #         final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
 
-        # any passed in kwargs override the defaults
-        final_engine_kwargs.update(engine_kwargs)
+    #     # any passed in kwargs override the defaults
+    #     final_engine_kwargs.update(engine_kwargs)
 
-        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
-        # echo=True here for inspecting all emitted db queries
-        engine = create_engine(connection_string, **final_engine_kwargs)
+    #     logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+    #     # echo=True here for inspecting all emitted db queries
+    #     engine = create_engine(connection_string, **final_engine_kwargs)
 
-        if USE_IAM_AUTH:
-            event.listen(engine, "do_connect", provide_iam_token)
+    #     if USE_IAM_AUTH:
+    #         event.listen(engine, "do_connect", provide_iam_token)
 
-        return engine
+    #     return engine
 
     @classmethod
     def init_engine(
@@ -232,20 +234,29 @@ class SqlEngine:
         pool_size: int,
         # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
        max_overflow: int,
+        app_name: str | None = None,
+        db_api: str = SYNC_DB_API,
+        use_iam: bool = USE_IAM_AUTH,
+        connection_string: str | None = None,
        **extra_engine_kwargs: Any,
     ) -> None:
         """NOTE: enforce that pool_size and pool_max_overflow are passed in.
         These are important args, and if incorrectly specified, we have run into hitting the pool
-        limit / using too many connections and overwhelming the database."""
+        limit / using too many connections and overwhelming the database.
+
+        Specifying connection_string directly causes the other connection-related
+        parameters to be ignored.
+        """
         with cls._lock:
             if cls._engine:
                 return
 
-            connection_string = build_connection_string(
-                db_api=SYNC_DB_API,
-                app_name=cls._app_name + "_sync",
-                use_iam=USE_IAM_AUTH,
-            )
+            if not connection_string:
+                connection_string = build_connection_string(
+                    db_api=db_api,
+                    app_name=(app_name or cls._app_name) + "_sync",
+                    use_iam_auth=use_iam,
+                )
 
             # Start with base kwargs that are valid for all pool types
             final_engine_kwargs: dict[str, Any] = {}
@@ -274,7 +285,7 @@ class SqlEngine:
             # echo=True here for inspecting all emitted db queries
             engine = create_engine(connection_string, **final_engine_kwargs)
 
-            if USE_IAM_AUTH:
+            if use_iam:
                 event.listen(engine, "do_connect", provide_iam_token)
 
             cls._engine = engine
@@ -306,6 +317,8 @@ class SqlEngine:
 def get_all_tenant_ids() -> list[str]:
     """Returning [None] means the only tenant is the 'public' or self hosted tenant."""
 
+    tenant_ids: list[str]
+
     if not MULTI_TENANT:
         return [POSTGRES_DEFAULT_SCHEMA]
 
@@ -354,7 +367,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         app_name = SqlEngine.get_app_name() + "_async"
         connection_string = build_connection_string(
             db_api=ASYNC_DB_API,
-            use_iam=USE_IAM_AUTH,
+            use_iam_auth=USE_IAM_AUTH,
         )
 
         connect_args: dict[str, Any] = {}
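A quick usage note between the two files: a minimal sketch of how the new `connection_string` override to `SqlEngine.init_engine` is meant to be used, mirroring what the new debugging script below does. The credential values and the `adhoc_debug_tool` app name are illustrative placeholders, not real defaults:

```python
from onyx.db.engine import SYNC_DB_API, USE_IAM_AUTH, SqlEngine, build_connection_string

# Build a connection string by hand instead of relying on the POSTGRES_*
# environment variables. All credentials here are placeholders.
conn_str = build_connection_string(
    db_api=SYNC_DB_API,
    app_name="adhoc_debug_tool",  # hypothetical app name
    use_iam_auth=USE_IAM_AUTH,
    user="postgres",
    password="example-password",  # placeholder
    host="localhost",
    port="5432",
    db="danswer",
)

# pool_size / max_overflow are still required; the explicit connection_string
# short-circuits the env-var based build_connection_string() call inside.
SqlEngine.init_engine(pool_size=20, max_overflow=5, connection_string=conn_str)
```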
diff --git a/backend/scripts/debugging/onyx_db.py b/backend/scripts/debugging/onyx_db.py
new file mode 100644
index 00000000000..4023449582c
--- /dev/null
+++ b/backend/scripts/debugging/onyx_db.py
@@ -0,0 +1,149 @@
+"""Onyx Database tool"""
+
+import os
+
+# hack to work around excessive use of globals in other functions
+os.environ["MULTI_TENANT"] = "True"
+
+if True:  # noqa: E402
+    import csv
+    import argparse
+
+    from pydantic import BaseModel
+    from sqlalchemy import func
+
+    from onyx.db.engine import (
+        SYNC_DB_API,
+        USE_IAM_AUTH,
+        build_connection_string,
+        get_all_tenant_ids,
+    )
+    from onyx.db.engine import get_session_with_tenant
+    from onyx.db.engine import SqlEngine
+    from onyx.db.models import Document
+    from onyx.utils.logger import setup_logger
+    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+
+    import heapq
+
+    logger = setup_logger()
+
+
+class TenantMetadata(BaseModel):
+    num_docs: int
+    num_chunks: int
+
+
+class SQLAlchemyDebugging:
+    """Class for managing DB debugging actions."""
+
+    def __init__(self) -> None:
+        pass
+
+    def top_chunks(self, k: int = 10) -> None:
+        tenants_to_total_chunks: dict[str, TenantMetadata] = {}
+
+        logger.info("Fetching all tenant ids.")
+        tenant_ids = get_all_tenant_ids()
+        num_tenant_ids = len(tenant_ids)
+
+        logger.info(f"Found {num_tenant_ids} tenant ids.")
+
+        num_processed = 0
+        for tenant_id in tenant_ids:
+            num_processed += 1
+
+            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+            try:
+                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+                    # Calculate the total number of document rows for the current tenant
+                    total_documents = db_session.query(Document).count()
+                    # marginally useful to skip some tenants ... maybe we can improve on this
+                    # if total_documents < 100:
+                    #     logger.info(f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                    #                 f"docs={total_documents} skip=True")
+                    #     continue
+
+                    # Calculate the sum of chunk_count for the current tenant
+                    # If there are no documents or all chunk_counts are NULL, sum will be None
+                    total_chunks = db_session.query(
+                        func.sum(Document.chunk_count)
+                    ).scalar()
+                    total_chunks = total_chunks or 0
+
+                    logger.info(
+                        f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                        f"docs={total_documents} chunks={total_chunks}"
+                    )
+
+                    tenants_to_total_chunks[tenant_id] = TenantMetadata(
+                        num_docs=total_documents, num_chunks=total_chunks
+                    )
+            except Exception as e:
+                logger.error(f"Error processing tenant '{tenant_id}': {e}")
+            finally:
+                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+
+        # sort all tenants by chunk count (then doc count) and dump to csv
+        sorted_tenants = sorted(
+            tenants_to_total_chunks.items(),
+            key=lambda x: (x[1].num_chunks, x[1].num_docs),
+            reverse=True,
+        )
+
+        csv_filename = "tenants_by_num_docs.csv"
+        with open(csv_filename, "w") as csvfile:
+            writer = csv.writer(csvfile)
+            writer.writerow(["tenant_id", "num_docs", "num_chunks"])  # Write header
+            # Write data rows (using the sorted list)
+            for tenant_id, metadata in sorted_tenants:
+                writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])
+        logger.info(f"Successfully wrote statistics to {csv_filename}")
+
+        # output top k by chunks
+        top_k_tenants = heapq.nlargest(
+            k, tenants_to_total_chunks.items(), key=lambda x: x[1].num_chunks
+        )
+
+        logger.info(f"Top {k} tenants by total chunks: {top_k_tenants}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Database/SQL debugging tool")
+    parser.add_argument("--username", help="Database username", default="postgres")
+    parser.add_argument("--password", help="Database password", required=True)
+    parser.add_argument("--host", help="Database host", default="localhost")
+    parser.add_argument("--port", help="Database port", default="5432")
+    parser.add_argument("--db", help="Database default db name", default="danswer")
+
+    parser.add_argument("--report", help="Generate the given report (e.g. top-chunks)")
+
+    args = parser.parse_args()
+
+    logger.info(str({**vars(args), "password": "<redacted>"}))
+
+    connection_string = build_connection_string(
+        db_api=SYNC_DB_API,
+        app_name="onyx_db_sync",
+        use_iam_auth=USE_IAM_AUTH,
+        user=args.username,
+        password=args.password,
+        host=args.host,
+        port=args.port,
+        db=args.db,
+    )
+
+    SqlEngine.init_engine(
+        pool_size=20, max_overflow=5, connection_string=connection_string
+    )
+
+    debugger = SQLAlchemyDebugging()
+
+    if args.report == "top-chunks":
+        debugger.top_chunks(10)
+    else:
+        logger.info("No action.")
+
+
+if __name__ == "__main__":
+    main()
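`provide_iam_token` is referenced above but its body is outside this diff. For context, a `do_connect` hook of this kind typically swaps the password for a short-lived RDS IAM auth token just before the DBAPI connection is made; a rough sketch under that assumption (not the actual implementation in engine.py):

```python
import boto3


def provide_iam_token_sketch(dialect, conn_rec, cargs, cparams) -> None:
    # SQLAlchemy passes "do_connect" listeners the connect args/params and
    # lets them mutate cparams in place before the DBAPI connect call.
    client = boto3.client("rds", region_name="us-west-2")
    cparams["password"] = client.generate_db_auth_token(
        DBHostname=cparams["host"],
        Port=int(cparams.get("port", 5432)),
        DBUsername=cparams["user"],
    )
```

With the defaults above, the new script would be invoked roughly as `python backend/scripts/debugging/onyx_db.py --password <pw> --report top-chunks`.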