Feature/db script (#4574)

* debug script + slight refactor of db class

* better comments

* move setup logger

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
rkuo-danswer 2025-04-23 13:00:35 -07:00 committed by GitHub
parent c9a609b7d8
commit ea1d3c1eda
2 changed files with 205 additions and 43 deletions

View File

@@ -55,8 +55,12 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"
 
+# why isn't this in configs?
 USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
 
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+
 # Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None
@@ -106,10 +110,10 @@ def build_connection_string(
     port: str = POSTGRES_PORT,
     db: str = POSTGRES_DB,
     app_name: str | None = None,
-    use_iam: bool = USE_IAM_AUTH,
+    use_iam_auth: bool = USE_IAM_AUTH,
     region: str = "us-west-2",
 ) -> str:
-    if use_iam:
+    if use_iam_auth:
         base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
     else:
         base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
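
The keyword argument is renamed from use_iam to use_iam_auth. A minimal sketch of a caller after the rename; every value below is an illustrative placeholder, not something taken from this commit:

# Hypothetical caller, shown for illustration only.
conn_str = build_connection_string(
    db_api=SYNC_DB_API,
    app_name="example_tool_sync",
    use_iam_auth=False,  # renamed keyword (was use_iam)
    user="postgres",
    password="postgres",
    host="localhost",
    port="5432",
    db="danswer",
)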
@@ -176,9 +180,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
-
-
 def is_valid_schema_name(name: str) -> bool:
     return SCHEMA_NAME_REGEX.match(name) is not None
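
The regex itself is unchanged by this commit; it only moves up next to the other module constants. For reference, a small sketch of what the check accepts, based solely on the pattern shown above:

# Schema names may contain only letters, digits, underscore, and hyphen.
assert is_valid_schema_name("tenant_123-abc")
assert not is_valid_schema_name("tenant; DROP TABLE documents")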
@@ -188,43 +189,44 @@ class SqlEngine:
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME
 
-    @classmethod
-    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
-        )
-        # Start with base kwargs that are valid for all pool types
-        final_engine_kwargs: dict[str, Any] = {}
-        if POSTGRES_USE_NULL_POOL:
-            # if null pool is specified, then we need to make sure that
-            # we remove any passed in kwargs related to pool size that would
-            # cause the initialization to fail
-            final_engine_kwargs.update(engine_kwargs)
-            final_engine_kwargs["poolclass"] = pool.NullPool
-            if "pool_size" in final_engine_kwargs:
-                del final_engine_kwargs["pool_size"]
-            if "max_overflow" in final_engine_kwargs:
-                del final_engine_kwargs["max_overflow"]
-        else:
-            final_engine_kwargs["pool_size"] = 20
-            final_engine_kwargs["max_overflow"] = 5
-            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
-            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
-        # any passed in kwargs override the defaults
-        final_engine_kwargs.update(engine_kwargs)
-        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
-        # echo=True here for inspecting all emitted db queries
-        engine = create_engine(connection_string, **final_engine_kwargs)
-        if USE_IAM_AUTH:
-            event.listen(engine, "do_connect", provide_iam_token)
-        return engine
+    # NOTE(rkuo) - this appears to be unused, clean it up?
+    # @classmethod
+    # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+    #     connection_string = build_connection_string(
+    #         db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
+    #     )
+    #     # Start with base kwargs that are valid for all pool types
+    #     final_engine_kwargs: dict[str, Any] = {}
+    #     if POSTGRES_USE_NULL_POOL:
+    #         # if null pool is specified, then we need to make sure that
+    #         # we remove any passed in kwargs related to pool size that would
+    #         # cause the initialization to fail
+    #         final_engine_kwargs.update(engine_kwargs)
+    #         final_engine_kwargs["poolclass"] = pool.NullPool
+    #         if "pool_size" in final_engine_kwargs:
+    #             del final_engine_kwargs["pool_size"]
+    #         if "max_overflow" in final_engine_kwargs:
+    #             del final_engine_kwargs["max_overflow"]
+    #     else:
+    #         final_engine_kwargs["pool_size"] = 20
+    #         final_engine_kwargs["max_overflow"] = 5
+    #         final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+    #         final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+    #     # any passed in kwargs override the defaults
+    #     final_engine_kwargs.update(engine_kwargs)
+    #     logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+    #     # echo=True here for inspecting all emitted db queries
+    #     engine = create_engine(connection_string, **final_engine_kwargs)
+    #     if USE_IAM_AUTH:
+    #         event.listen(engine, "do_connect", provide_iam_token)
+    #     return engine
 
     @classmethod
     def init_engine(
@@ -232,20 +234,29 @@ class SqlEngine:
         pool_size: int,
         # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
         max_overflow: int,
+        app_name: str | None = None,
+        db_api: str = SYNC_DB_API,
+        use_iam: bool = USE_IAM_AUTH,
+        connection_string: str | None = None,
         **extra_engine_kwargs: Any,
     ) -> None:
         """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
         important args, and if incorrectly specified, we have run into hitting the pool
-        limit / using too many connections and overwhelming the database."""
+        limit / using too many connections and overwhelming the database.
+
+        Specifying connection_string directly will cause some of the other parameters
+        to be ignored.
+        """
         with cls._lock:
             if cls._engine:
                 return
 
-            connection_string = build_connection_string(
-                db_api=SYNC_DB_API,
-                app_name=cls._app_name + "_sync",
-                use_iam=USE_IAM_AUTH,
-            )
+            if not connection_string:
+                connection_string = build_connection_string(
+                    db_api=db_api,
+                    app_name=cls._app_name + "_sync",
+                    use_iam_auth=use_iam,
+                )
 
             # Start with base kwargs that are valid for all pool types
             final_engine_kwargs: dict[str, Any] = {}
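
With the new parameters above, init_engine can either build its own connection string (the old behavior) or accept one that was built elsewhere. A minimal sketch of the override path; the pool numbers mirror the debug script added later in this commit, and everything else is left at its defaults:

# Hypothetical caller: a pre-built connection string makes init_engine skip
# build_connection_string, per the updated docstring.
conn_str = build_connection_string(db_api=SYNC_DB_API, use_iam_auth=False)
SqlEngine.init_engine(pool_size=20, max_overflow=5, connection_string=conn_str)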
@@ -274,7 +285,7 @@
             # echo=True here for inspecting all emitted db queries
             engine = create_engine(connection_string, **final_engine_kwargs)
 
-            if USE_IAM_AUTH:
+            if use_iam:
                 event.listen(engine, "do_connect", provide_iam_token)
 
             cls._engine = engine
@@ -306,6 +317,8 @@ class SqlEngine:
 
 def get_all_tenant_ids() -> list[str]:
     """Returning [None] means the only tenant is the 'public' or self hosted tenant."""
+    tenant_ids: list[str]
+
     if not MULTI_TENANT:
         return [POSTGRES_DEFAULT_SCHEMA]
@@ -354,7 +367,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         app_name = SqlEngine.get_app_name() + "_async"
         connection_string = build_connection_string(
             db_api=ASYNC_DB_API,
-            use_iam=USE_IAM_AUTH,
+            use_iam_auth=USE_IAM_AUTH,
         )
 
         connect_args: dict[str, Any] = {}

View File

@@ -0,0 +1,149 @@
"""Onyx Database tool"""
import os

# hack to work around excessive use of globals in other functions
os.environ["MULTI_TENANT"] = "True"

if True:  # noqa: E402
    import csv
    import argparse

    from pydantic import BaseModel
    from sqlalchemy import func

    from onyx.db.engine import (
        SYNC_DB_API,
        USE_IAM_AUTH,
        build_connection_string,
        get_all_tenant_ids,
    )
    from onyx.db.engine import get_session_with_tenant
    from onyx.db.engine import SqlEngine
    from onyx.db.models import Document
    from onyx.utils.logger import setup_logger
    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

    import heapq

logger = setup_logger()


class TenantMetadata(BaseModel):
    num_docs: int
    num_chunks: int


class SQLAlchemyDebugging:
    # Class for managing DB debugging actions.
    def __init__(self) -> None:
        pass

    def top_chunks(self, k: int = 10) -> None:
        tenants_to_total_chunks: dict[str, TenantMetadata] = {}

        logger.info("Fetching all tenant id's.")
        tenant_ids = get_all_tenant_ids()
        num_tenant_ids = len(tenant_ids)
        logger.info(f"Found {num_tenant_ids} tenant id's.")

        num_processed = 0
        for tenant_id in tenant_ids:
            num_processed += 1
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
            try:
                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
                    # Calculate the total number of document rows for the current tenant
                    total_documents = db_session.query(Document).count()

                    # marginally useful to skip some tenants ... maybe we can improve on this
                    # if total_documents < 100:
                    #     logger.info(f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
                    #                 f"docs={total_documents} skip=True")
                    #     continue

                    # Calculate the sum of chunk_count for the current tenant
                    # If there are no documents or all chunk_counts are NULL, sum will be None
                    total_chunks = db_session.query(
                        func.sum(Document.chunk_count)
                    ).scalar()
                    total_chunks = total_chunks or 0

                    logger.info(
                        f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
                        f"docs={total_documents} chunks={total_chunks}"
                    )

                    tenants_to_total_chunks[tenant_id] = TenantMetadata(
                        num_docs=total_documents, num_chunks=total_chunks
                    )
            except Exception as e:
                logger.error(f"Error processing tenant '{tenant_id}': {e}")
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        # sort all by docs and dump to csv
        sorted_tenants = sorted(
            tenants_to_total_chunks.items(),
            key=lambda x: (x[1].num_chunks, x[1].num_docs),
            reverse=True,
        )

        csv_filename = "tenants_by_num_docs.csv"
        with open(csv_filename, "w") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["tenant_id", "num_docs", "num_chunks"])  # Write header
            # Write data rows (using the sorted list)
            for tenant_id, metadata in sorted_tenants:
                writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])

        logger.info(f"Successfully wrote statistics to {csv_filename}")

        # output top k by chunks
        top_k_tenants = heapq.nlargest(
            k, tenants_to_total_chunks.items(), key=lambda x: x[1].num_docs
        )
        logger.info(f"Top {k} tenants by total chunks: {top_k_tenants}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Database/SQL debugging tool")
    parser.add_argument("--username", help="Database username", default="postgres")
    parser.add_argument("--password", help="Database password", required=True)
    parser.add_argument("--host", help="Database host", default="localhost")
    parser.add_argument("--port", help="Database port", default=5432)
    parser.add_argument("--db", help="Database default db name", default="danswer")
    parser.add_argument("--report", help="Generate the given report")
    args = parser.parse_args()

    logger.info(f"{args}")

    connection_string = build_connection_string(
        db_api=SYNC_DB_API,
        app_name="onyx_db_sync",
        use_iam_auth=USE_IAM_AUTH,
        user=args.username,
        password=args.password,
        host=args.host,
        port=args.port,
        db=args.db,
    )

    SqlEngine.init_engine(
        pool_size=20, max_overflow=5, connection_string=connection_string
    )

    debugger = SQLAlchemyDebugging()

    if args.report == "top-chunks":
        debugger.top_chunks(10)
    else:
        logger.info("No action.")


if __name__ == "__main__":
    main()
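
Usage note: only --password is required; --report top-chunks selects the one implemented report, and the other flags fall back to the defaults shown in main(). The script's path is not shown in this view. A condensed, hedged sketch of the same flow driven without argparse; the credentials and host below are placeholders:

# Placeholder values; mirrors what main() does for --report top-chunks.
conn_str = build_connection_string(
    db_api=SYNC_DB_API,
    app_name="onyx_db_sync",
    use_iam_auth=USE_IAM_AUTH,
    user="postgres",
    password="example-password",
    host="localhost",
    port="5432",
    db="danswer",
)
SqlEngine.init_engine(pool_size=20, max_overflow=5, connection_string=conn_str)
SQLAlchemyDebugging().top_chunks(10)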