Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-05 04:31:03 +02:00)

Feature/db script (#4574)

* debug script + slight refactor of db class
* better comments
* move setup logger

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>

parent c9a609b7d8
commit ea1d3c1eda
backend/onyx/db/engine.py

@@ -55,8 +55,12 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"
 
+# why isn't this in configs?
 USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
 
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+
 # Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None
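Note on the flag above: only the literal string "true" (any casing) enables IAM auth; other truthy-looking values do not. A standalone sketch of that parsing behavior:

    import os

    # "TRUE"/"True"/"true" all enable the flag...
    os.environ["USE_IAM_AUTH"] = "TRUE"
    assert os.getenv("USE_IAM_AUTH", "False").lower() == "true"

    # ...but values like "1" or "yes" are treated as false.
    os.environ["USE_IAM_AUTH"] = "1"
    assert os.getenv("USE_IAM_AUTH", "False").lower() != "true"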
@@ -106,10 +110,10 @@ def build_connection_string(
     port: str = POSTGRES_PORT,
     db: str = POSTGRES_DB,
     app_name: str | None = None,
-    use_iam: bool = USE_IAM_AUTH,
+    use_iam_auth: bool = USE_IAM_AUTH,
     region: str = "us-west-2",
 ) -> str:
-    if use_iam:
+    if use_iam_auth:
         base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
     else:
         base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
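The use_iam -> use_iam_auth rename is mechanical, but the branch it guards determines the DSN shape. A hedged sketch of the two outputs (identifiers from this diff; db_api is passed explicitly since its default is not shown in this hunk):

    from onyx.db.engine import SYNC_DB_API, build_connection_string

    # IAM mode: the password is omitted from the DSN; an IAM token is injected
    # later through the "do_connect" listener registered in init_engine.
    iam_dsn = build_connection_string(db_api=SYNC_DB_API, use_iam_auth=True)
    # e.g. postgresql+psycopg2://<user>@<host>:<port>/<db>

    # Password mode: credentials are embedded directly in the DSN.
    pw_dsn = build_connection_string(db_api=SYNC_DB_API, use_iam_auth=False)
    # e.g. postgresql+psycopg2://<user>:<password>@<host>:<port>/<db>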
@@ -176,9 +180,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
-
-
 def is_valid_schema_name(name: str) -> bool:
     return SCHEMA_NAME_REGEX.match(name) is not None
 
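Since SCHEMA_NAME_REGEX now lives at the top of the module, here is a self-contained illustration of the validator it feeds (the injection-guard reading is an inference, not stated in the diff):

    import re

    SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

    def is_valid_schema_name(name: str) -> bool:
        return SCHEMA_NAME_REGEX.match(name) is not None

    # Only alphanumerics, underscore, and hyphen pass, so a validated name can
    # be interpolated into per-tenant SQL without quoting surprises.
    assert is_valid_schema_name("tenant_abc-123")
    assert not is_valid_schema_name('public"; DROP TABLE users; --')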
@@ -188,43 +189,44 @@ class SqlEngine:
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME
 
-    @classmethod
-    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
-        )
+    # NOTE(rkuo) - this appears to be unused, clean it up?
+    # @classmethod
+    # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+    #     connection_string = build_connection_string(
+    #         db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
+    #     )
 
-        # Start with base kwargs that are valid for all pool types
-        final_engine_kwargs: dict[str, Any] = {}
+    #     # Start with base kwargs that are valid for all pool types
+    #     final_engine_kwargs: dict[str, Any] = {}
 
-        if POSTGRES_USE_NULL_POOL:
-            # if null pool is specified, then we need to make sure that
-            # we remove any passed in kwargs related to pool size that would
-            # cause the initialization to fail
-            final_engine_kwargs.update(engine_kwargs)
+    #     if POSTGRES_USE_NULL_POOL:
+    #         # if null pool is specified, then we need to make sure that
+    #         # we remove any passed in kwargs related to pool size that would
+    #         # cause the initialization to fail
+    #         final_engine_kwargs.update(engine_kwargs)
 
-            final_engine_kwargs["poolclass"] = pool.NullPool
-            if "pool_size" in final_engine_kwargs:
-                del final_engine_kwargs["pool_size"]
-            if "max_overflow" in final_engine_kwargs:
-                del final_engine_kwargs["max_overflow"]
-        else:
-            final_engine_kwargs["pool_size"] = 20
-            final_engine_kwargs["max_overflow"] = 5
-            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
-            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+    #         final_engine_kwargs["poolclass"] = pool.NullPool
+    #         if "pool_size" in final_engine_kwargs:
+    #             del final_engine_kwargs["pool_size"]
+    #         if "max_overflow" in final_engine_kwargs:
+    #             del final_engine_kwargs["max_overflow"]
+    #     else:
+    #         final_engine_kwargs["pool_size"] = 20
+    #         final_engine_kwargs["max_overflow"] = 5
+    #         final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+    #         final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
 
-        # any passed in kwargs override the defaults
-        final_engine_kwargs.update(engine_kwargs)
+    #     # any passed in kwargs override the defaults
+    #     final_engine_kwargs.update(engine_kwargs)
 
-        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
-        # echo=True here for inspecting all emitted db queries
-        engine = create_engine(connection_string, **final_engine_kwargs)
+    #     logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+    #     # echo=True here for inspecting all emitted db queries
+    #     engine = create_engine(connection_string, **final_engine_kwargs)
 
-        if USE_IAM_AUTH:
-            event.listen(engine, "do_connect", provide_iam_token)
+    #     if USE_IAM_AUTH:
+    #         event.listen(engine, "do_connect", provide_iam_token)
 
-        return engine
+    #     return engine
 
     @classmethod
     def init_engine(
@@ -232,20 +234,29 @@ class SqlEngine:
         pool_size: int,
         # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
         max_overflow: int,
         app_name: str | None = None,
+        db_api: str = SYNC_DB_API,
+        use_iam: bool = USE_IAM_AUTH,
+        connection_string: str | None = None,
         **extra_engine_kwargs: Any,
     ) -> None:
         """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
         important args, and if incorrectly specified, we have run into hitting the pool
-        limit / using too many connections and overwhelming the database."""
+        limit / using too many connections and overwhelming the database.
+
+        Specifying connection_string directly will cause some of the other parameters
+        to be ignored.
+        """
         with cls._lock:
             if cls._engine:
                 return
 
-            connection_string = build_connection_string(
-                db_api=SYNC_DB_API,
-                app_name=cls._app_name + "_sync",
-                use_iam=USE_IAM_AUTH,
-            )
+            if not connection_string:
+                connection_string = build_connection_string(
+                    db_api=db_api,
+                    app_name=cls._app_name + "_sync",
+                    use_iam_auth=use_iam,
+                )
 
             # Start with base kwargs that are valid for all pool types
             final_engine_kwargs: dict[str, Any] = {}
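The new connection_string parameter is the hook the debug script below relies on. A hedged sketch of the override path, mirroring the call in onyx_db.py (the DSN here is a placeholder):

    from onyx.db.engine import SqlEngine

    # When connection_string is supplied, build_connection_string is skipped,
    # so credentials can come from anywhere (CLI args, a secrets manager, ...).
    SqlEngine.init_engine(
        pool_size=20,
        max_overflow=5,
        connection_string="postgresql+psycopg2://postgres:postgres@localhost:5432/danswer",
    )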
@@ -274,7 +285,7 @@ class SqlEngine:
             # echo=True here for inspecting all emitted db queries
             engine = create_engine(connection_string, **final_engine_kwargs)
 
-            if USE_IAM_AUTH:
+            if use_iam:
                 event.listen(engine, "do_connect", provide_iam_token)
 
             cls._engine = engine
@@ -306,6 +317,8 @@
 def get_all_tenant_ids() -> list[str]:
     """Returning [None] means the only tenant is the 'public' or self hosted tenant."""
 
+    tenant_ids: list[str]
+
     if not MULTI_TENANT:
         return [POSTGRES_DEFAULT_SCHEMA]
 
@@ -354,7 +367,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         app_name = SqlEngine.get_app_name() + "_async"
         connection_string = build_connection_string(
             db_api=ASYNC_DB_API,
-            use_iam=USE_IAM_AUTH,
+            use_iam_auth=USE_IAM_AUTH,
         )
 
         connect_args: dict[str, Any] = {}
backend/scripts/debugging/onyx_db.py (new file, +149 lines)

@@ -0,0 +1,149 @@
+"""Onyx Database tool"""
+
+import os
+
+# hack to work around excessive use of globals in other functions
+os.environ["MULTI_TENANT"] = "True"
+
+if True:  # noqa: E402
+    import csv
+    import argparse
+
+    from pydantic import BaseModel
+    from sqlalchemy import func
+
+    from onyx.db.engine import (
+        SYNC_DB_API,
+        USE_IAM_AUTH,
+        build_connection_string,
+        get_all_tenant_ids,
+    )
+    from onyx.db.engine import get_session_with_tenant
+    from onyx.db.engine import SqlEngine
+    from onyx.db.models import Document
+    from onyx.utils.logger import setup_logger
+    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+
+    import heapq
+
+logger = setup_logger()
+
+
+class TenantMetadata(BaseModel):
+    num_docs: int
+    num_chunks: int
+
+
+class SQLAlchemyDebugging:
+    # Class for managing DB debugging actions.
+    def __init__(self) -> None:
+        pass
+
+    def top_chunks(self, k: int = 10) -> None:
+        tenants_to_total_chunks: dict[str, TenantMetadata] = {}
+
+        logger.info("Fetching all tenant id's.")
+        tenant_ids = get_all_tenant_ids()
+        num_tenant_ids = len(tenant_ids)
+
+        logger.info(f"Found {num_tenant_ids} tenant id's.")
+
+        num_processed = 0
+        for tenant_id in tenant_ids:
+            num_processed += 1
+
+            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+            try:
+                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+                    # Calculate the total number of document rows for the current tenant
+                    total_documents = db_session.query(Document).count()
+                    # marginally useful to skip some tenants ... maybe we can improve on this
+                    # if total_documents < 100:
+                    #     logger.info(f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                    #                 f"docs={total_documents} skip=True")
+                    #     continue
+
+                    # Calculate the sum of chunk_count for the current tenant
+                    # If there are no documents or all chunk_counts are NULL, sum will be None
+                    total_chunks = db_session.query(
+                        func.sum(Document.chunk_count)
+                    ).scalar()
+                    total_chunks = total_chunks or 0
+
+                    logger.info(
+                        f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                        f"docs={total_documents} chunks={total_chunks}"
+                    )
+
+                    tenants_to_total_chunks[tenant_id] = TenantMetadata(
+                        num_docs=total_documents, num_chunks=total_chunks
+                    )
+            except Exception as e:
+                logger.error(f"Error processing tenant '{tenant_id}': {e}")
+            finally:
+                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+
+        # sort all by docs and dump to csv
+        sorted_tenants = sorted(
+            tenants_to_total_chunks.items(),
+            key=lambda x: (x[1].num_chunks, x[1].num_docs),
+            reverse=True,
+        )
+
+        csv_filename = "tenants_by_num_docs.csv"
+        with open(csv_filename, "w") as csvfile:
+            writer = csv.writer(csvfile)
+            writer.writerow(["tenant_id", "num_docs", "num_chunks"])  # Write header
+            # Write data rows (using the sorted list)
+            for tenant_id, metadata in sorted_tenants:
+                writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])
+        logger.info(f"Successfully wrote statistics to {csv_filename}")
+
+        # output top k by chunks
+        top_k_tenants = heapq.nlargest(
+            k, tenants_to_total_chunks.items(), key=lambda x: x[1].num_docs
+        )
+
+        logger.info(f"Top {k} tenants by total chunks: {top_k_tenants}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Database/SQL debugging tool")
+    parser.add_argument("--username", help="Database username", default="postgres")
+    parser.add_argument("--password", help="Database password", required=True)
+    parser.add_argument("--host", help="Database host", default="localhost")
+    parser.add_argument("--port", help="Database port", default=5432)
+    parser.add_argument("--db", help="Database default db name", default="danswer")
+
+    parser.add_argument("--report", help="Generate the given report")
+
+    args = parser.parse_args()
+
+    logger.info(f"{args}")
+
+    connection_string = build_connection_string(
+        db_api=SYNC_DB_API,
+        app_name="onyx_db_sync",
+        use_iam_auth=USE_IAM_AUTH,
+        user=args.username,
+        password=args.password,
+        host=args.host,
+        port=args.port,
+        db=args.db,
+    )
+
+    SqlEngine.init_engine(
+        pool_size=20, max_overflow=5, connection_string=connection_string
+    )
+
+    debugger = SQLAlchemyDebugging()
+
+    if args.report == "top-chunks":
+        debugger.top_chunks(10)
+    else:
+        logger.info("No action.")
+
+
+if __name__ == "__main__":
+    main()
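For reference, a hedged usage sketch of the new report: the flags are the ones defined in main() above, while the working directory (backend/, so the onyx package resolves) and the credentials are assumptions.

    import subprocess

    # Logs per-tenant doc/chunk counts and, on success, writes
    # tenants_by_num_docs.csv to the current directory.
    subprocess.run(
        [
            "python",
            "scripts/debugging/onyx_db.py",
            "--password", "postgres",
            "--report", "top-chunks",
        ],
        check=True,
    )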