Feature/db script (#4574)

* debug script + slight refactor of db class

* better comments

* move setup logger

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
rkuo-danswer authored 2025-04-23 13:00:35 -07:00, committed by GitHub
parent c9a609b7d8
commit ea1d3c1eda
2 changed files with 205 additions and 43 deletions


@@ -55,8 +55,12 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"
 
+# why isn't this in configs?
 USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
 
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+# Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None
 
@@ -106,10 +110,10 @@ def build_connection_string(
     port: str = POSTGRES_PORT,
     db: str = POSTGRES_DB,
     app_name: str | None = None,
-    use_iam: bool = USE_IAM_AUTH,
+    use_iam_auth: bool = USE_IAM_AUTH,
     region: str = "us-west-2",
 ) -> str:
-    if use_iam:
+    if use_iam_auth:
         base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
     else:
         base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
@@ -176,9 +180,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
-
-
 def is_valid_schema_name(name: str) -> bool:
     return SCHEMA_NAME_REGEX.match(name) is not None
 
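A quick sketch of what the schema-name guard accepts, per the regex above (only letters, digits, underscore, and hyphen):

    is_valid_schema_name("tenant_abc-123")  # True
    is_valid_schema_name("tenant; DROP")    # False - blocks anything that could smuggle SQL into a schema name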
@@ -188,43 +189,44 @@ class SqlEngine:
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME
 
-    @classmethod
-    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
-        )
+    # NOTE(rkuo) - this appears to be unused, clean it up?
+    # @classmethod
+    # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+    #     connection_string = build_connection_string(
+    #         db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
+    #     )
 
-        # Start with base kwargs that are valid for all pool types
-        final_engine_kwargs: dict[str, Any] = {}
+    #     # Start with base kwargs that are valid for all pool types
+    #     final_engine_kwargs: dict[str, Any] = {}
 
-        if POSTGRES_USE_NULL_POOL:
-            # if null pool is specified, then we need to make sure that
-            # we remove any passed in kwargs related to pool size that would
-            # cause the initialization to fail
-            final_engine_kwargs.update(engine_kwargs)
+    #     if POSTGRES_USE_NULL_POOL:
+    #         # if null pool is specified, then we need to make sure that
+    #         # we remove any passed in kwargs related to pool size that would
+    #         # cause the initialization to fail
+    #         final_engine_kwargs.update(engine_kwargs)
 
-            final_engine_kwargs["poolclass"] = pool.NullPool
-            if "pool_size" in final_engine_kwargs:
-                del final_engine_kwargs["pool_size"]
-            if "max_overflow" in final_engine_kwargs:
-                del final_engine_kwargs["max_overflow"]
-        else:
-            final_engine_kwargs["pool_size"] = 20
-            final_engine_kwargs["max_overflow"] = 5
-            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
-            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+    #         final_engine_kwargs["poolclass"] = pool.NullPool
+    #         if "pool_size" in final_engine_kwargs:
+    #             del final_engine_kwargs["pool_size"]
+    #         if "max_overflow" in final_engine_kwargs:
+    #             del final_engine_kwargs["max_overflow"]
+    #     else:
+    #         final_engine_kwargs["pool_size"] = 20
+    #         final_engine_kwargs["max_overflow"] = 5
+    #         final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+    #         final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
 
-        # any passed in kwargs override the defaults
-        final_engine_kwargs.update(engine_kwargs)
+    #     # any passed in kwargs override the defaults
+    #     final_engine_kwargs.update(engine_kwargs)
 
-        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
-        # echo=True here for inspecting all emitted db queries
-        engine = create_engine(connection_string, **final_engine_kwargs)
+    #     logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+    #     # echo=True here for inspecting all emitted db queries
+    #     engine = create_engine(connection_string, **final_engine_kwargs)
 
-        if USE_IAM_AUTH:
-            event.listen(engine, "do_connect", provide_iam_token)
+    #     if USE_IAM_AUTH:
+    #         event.listen(engine, "do_connect", provide_iam_token)
 
-        return engine
+    #     return engine
 
     @classmethod
     def init_engine(
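For context on the pool handling this dead block duplicated (the live version still runs in init_engine below): SQLAlchemy rejects pooling kwargs when NullPool is selected, which is why pool_size and max_overflow are stripped first. A minimal sketch with a placeholder URL:

    from sqlalchemy import create_engine, pool

    url = "postgresql+psycopg2://postgres:secret@localhost:5432/danswer"  # placeholder

    # create_engine(url, poolclass=pool.NullPool, pool_size=20)  # raises TypeError
    engine = create_engine(url, poolclass=pool.NullPool)  # one fresh connection per checkout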
@@ -232,20 +234,29 @@ class SqlEngine:
         pool_size: int,
         # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
         max_overflow: int,
+        app_name: str | None = None,
+        db_api: str = SYNC_DB_API,
+        use_iam: bool = USE_IAM_AUTH,
+        connection_string: str | None = None,
         **extra_engine_kwargs: Any,
     ) -> None:
         """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
         important args, and if incorrectly specified, we have run into hitting the pool
-        limit / using too many connections and overwhelming the database."""
+        limit / using too many connections and overwhelming the database.
+
+        Specifying connection_string directly will cause some of the other parameters
+        to be ignored.
+        """
         with cls._lock:
             if cls._engine:
                 return
 
-            connection_string = build_connection_string(
-                db_api=SYNC_DB_API,
-                app_name=cls._app_name + "_sync",
-                use_iam=USE_IAM_AUTH,
-            )
+            if not connection_string:
+                connection_string = build_connection_string(
+                    db_api=db_api,
+                    app_name=cls._app_name + "_sync",
+                    use_iam_auth=use_iam,
+                )
 
             # Start with base kwargs that are valid for all pool types
             final_engine_kwargs: dict[str, Any] = {}
@@ -274,7 +285,7 @@ class SqlEngine:
             # echo=True here for inspecting all emitted db queries
             engine = create_engine(connection_string, **final_engine_kwargs)
 
-            if USE_IAM_AUTH:
+            if use_iam:
                 event.listen(engine, "do_connect", provide_iam_token)
 
             cls._engine = engine
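A sketch of the call pattern the new connection_string parameter enables, mirroring the debugging script added below (credentials are placeholders):

    conn_str = build_connection_string(
        db_api=SYNC_DB_API,
        user="postgres",
        password="secret",
        host="localhost",
        port="5432",
        db="danswer",
        use_iam_auth=False,
    )
    SqlEngine.init_engine(pool_size=20, max_overflow=5, connection_string=conn_str)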
@@ -306,6 +317,8 @@
 
 def get_all_tenant_ids() -> list[str]:
     """Returning [None] means the only tenant is the 'public' or self hosted tenant."""
+    tenant_ids: list[str]
+
     if not MULTI_TENANT:
         return [POSTGRES_DEFAULT_SCHEMA]
@@ -354,7 +367,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         app_name = SqlEngine.get_app_name() + "_async"
         connection_string = build_connection_string(
             db_api=ASYNC_DB_API,
-            use_iam=USE_IAM_AUTH,
+            use_iam_auth=USE_IAM_AUTH,
         )
 
         connect_args: dict[str, Any] = {}


@@ -0,0 +1,149 @@
"""Onyx Database tool"""
import os

# hack to work around excessive use of globals in other functions
os.environ["MULTI_TENANT"] = "True"

if True:  # noqa: E402
    import csv
    import argparse

    from pydantic import BaseModel
    from sqlalchemy import func

    from onyx.db.engine import (
        SYNC_DB_API,
        USE_IAM_AUTH,
        build_connection_string,
        get_all_tenant_ids,
    )
    from onyx.db.engine import get_session_with_tenant
    from onyx.db.engine import SqlEngine
    from onyx.db.models import Document
    from onyx.utils.logger import setup_logger
    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

    import heapq

logger = setup_logger()


class TenantMetadata(BaseModel):
    num_docs: int
    num_chunks: int


class SQLAlchemyDebugging:
    """Class for managing DB debugging actions."""

    def __init__(self) -> None:
        pass

    def top_chunks(self, k: int = 10) -> None:
        """Log per-tenant document and chunk counts, write them to a CSV,
        and report the top k tenants by total chunks."""
        tenants_to_total_chunks: dict[str, TenantMetadata] = {}

        logger.info("Fetching all tenant IDs.")
        tenant_ids = get_all_tenant_ids()
        num_tenant_ids = len(tenant_ids)
        logger.info(f"Found {num_tenant_ids} tenant IDs.")

        num_processed = 0
        for tenant_id in tenant_ids:
            num_processed += 1
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
            try:
                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
                    # Calculate the total number of document rows for the current tenant
                    total_documents = db_session.query(Document).count()

                    # marginally useful to skip some tenants ... maybe we can improve on this
                    # if total_documents < 100:
                    #     logger.info(f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
                    #                 f"docs={total_documents} skip=True")
                    #     continue

                    # Calculate the sum of chunk_count for the current tenant.
                    # If there are no documents or all chunk_counts are NULL, sum will be None
                    total_chunks = db_session.query(
                        func.sum(Document.chunk_count)
                    ).scalar()
                    total_chunks = total_chunks or 0

                    logger.info(
                        f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
                        f"docs={total_documents} chunks={total_chunks}"
                    )
                    tenants_to_total_chunks[tenant_id] = TenantMetadata(
                        num_docs=total_documents, num_chunks=total_chunks
                    )
            except Exception as e:
                logger.error(f"Error processing tenant '{tenant_id}': {e}")
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        # sort all by chunks (then docs) and dump to csv
        sorted_tenants = sorted(
            tenants_to_total_chunks.items(),
            key=lambda x: (x[1].num_chunks, x[1].num_docs),
            reverse=True,
        )

        csv_filename = "tenants_by_num_docs.csv"
        with open(csv_filename, "w") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["tenant_id", "num_docs", "num_chunks"])  # Write header
            # Write data rows (using the sorted list)
            for tenant_id, metadata in sorted_tenants:
                writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])

        logger.info(f"Successfully wrote statistics to {csv_filename}")

        # output top k by chunks
        top_k_tenants = heapq.nlargest(
            k, tenants_to_total_chunks.items(), key=lambda x: x[1].num_chunks
        )
        logger.info(f"Top {k} tenants by total chunks: {top_k_tenants}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Database/SQL debugging tool")
    parser.add_argument("--username", help="Database username", default="postgres")
    parser.add_argument("--password", help="Database password", required=True)
    parser.add_argument("--host", help="Database host", default="localhost")
    parser.add_argument("--port", help="Database port", default="5432")
    parser.add_argument("--db", help="Database default db name", default="danswer")
    parser.add_argument("--report", help="Generate the given report")
    args = parser.parse_args()

    logger.info(f"{args}")

    connection_string = build_connection_string(
        db_api=SYNC_DB_API,
        app_name="onyx_db_sync",
        use_iam_auth=USE_IAM_AUTH,
        user=args.username,
        password=args.password,
        host=args.host,
        port=args.port,
        db=args.db,
    )

    SqlEngine.init_engine(
        pool_size=20, max_overflow=5, connection_string=connection_string
    )

    debugger = SQLAlchemyDebugging()

    if args.report == "top-chunks":
        debugger.top_chunks(10)
    else:
        logger.info("No action.")


if __name__ == "__main__":
    main()
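A hedged usage sketch: the diff does not show the script's filename or location, so the path below is an assumption; run it wherever the onyx backend package is importable.

    # Writes tenants_by_num_docs.csv and logs the top 10 tenants by chunk count.
    python scripts/onyx_db.py --password mypassword --host localhost --report top-chunks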