Feature/db script (#4574)
* debug script + slight refactor of db class
* better comments
* move setup logger

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>

commit ea1d3c1eda
parent c9a609b7d8
backend/onyx/db/engine.py
@@ -55,8 +55,12 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"
 
+# why isn't this in configs?
 USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
 
+SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+
 # Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None
@@ -106,10 +110,10 @@ def build_connection_string(
     port: str = POSTGRES_PORT,
     db: str = POSTGRES_DB,
     app_name: str | None = None,
-    use_iam: bool = USE_IAM_AUTH,
+    use_iam_auth: bool = USE_IAM_AUTH,
     region: str = "us-west-2",
 ) -> str:
-    if use_iam:
+    if use_iam_auth:
         base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
     else:
         base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
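Note on the rename: build_connection_string now takes use_iam_auth instead of use_iam, so existing callers need the new keyword. A minimal sketch of an updated call, assuming the defaults shown above (the app name here is a placeholder, not from the commit):

    # Hypothetical caller updated for the renamed keyword (previously use_iam=...)
    conn_str = build_connection_string(
        db_api=SYNC_DB_API,             # "psycopg2"
        app_name="example_tool_sync",   # placeholder app name
        use_iam_auth=USE_IAM_AUTH,      # renamed parameter
    )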
@@ -176,9 +180,6 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
-SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
-
-
 def is_valid_schema_name(name: str) -> bool:
     return SCHEMA_NAME_REGEX.match(name) is not None
 
@@ -188,43 +189,44 @@ class SqlEngine:
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME
 
-    @classmethod
-    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
-        )
+    # NOTE(rkuo) - this appears to be unused, clean it up?
+    # @classmethod
+    # def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+    #     connection_string = build_connection_string(
+    #         db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
+    #     )
 
-        # Start with base kwargs that are valid for all pool types
-        final_engine_kwargs: dict[str, Any] = {}
+    #     # Start with base kwargs that are valid for all pool types
+    #     final_engine_kwargs: dict[str, Any] = {}
 
-        if POSTGRES_USE_NULL_POOL:
-            # if null pool is specified, then we need to make sure that
-            # we remove any passed in kwargs related to pool size that would
-            # cause the initialization to fail
-            final_engine_kwargs.update(engine_kwargs)
+    #     if POSTGRES_USE_NULL_POOL:
+    #         # if null pool is specified, then we need to make sure that
+    #         # we remove any passed in kwargs related to pool size that would
+    #         # cause the initialization to fail
+    #         final_engine_kwargs.update(engine_kwargs)
 
-            final_engine_kwargs["poolclass"] = pool.NullPool
-            if "pool_size" in final_engine_kwargs:
-                del final_engine_kwargs["pool_size"]
-            if "max_overflow" in final_engine_kwargs:
-                del final_engine_kwargs["max_overflow"]
-        else:
-            final_engine_kwargs["pool_size"] = 20
-            final_engine_kwargs["max_overflow"] = 5
-            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
-            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+    #         final_engine_kwargs["poolclass"] = pool.NullPool
+    #         if "pool_size" in final_engine_kwargs:
+    #             del final_engine_kwargs["pool_size"]
+    #         if "max_overflow" in final_engine_kwargs:
+    #             del final_engine_kwargs["max_overflow"]
+    #     else:
+    #         final_engine_kwargs["pool_size"] = 20
+    #         final_engine_kwargs["max_overflow"] = 5
+    #         final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+    #         final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
 
-            # any passed in kwargs override the defaults
-            final_engine_kwargs.update(engine_kwargs)
+    #         # any passed in kwargs override the defaults
+    #         final_engine_kwargs.update(engine_kwargs)
 
-        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
-        # echo=True here for inspecting all emitted db queries
-        engine = create_engine(connection_string, **final_engine_kwargs)
+    #     logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+    #     # echo=True here for inspecting all emitted db queries
+    #     engine = create_engine(connection_string, **final_engine_kwargs)
 
-        if USE_IAM_AUTH:
-            event.listen(engine, "do_connect", provide_iam_token)
+    #     if USE_IAM_AUTH:
+    #         event.listen(engine, "do_connect", provide_iam_token)
 
-        return engine
+    #     return engine
 
     @classmethod
     def init_engine(
@@ -232,20 +234,29 @@ class SqlEngine:
         pool_size: int,
         # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
         max_overflow: int,
+        app_name: str | None = None,
+        db_api: str = SYNC_DB_API,
+        use_iam: bool = USE_IAM_AUTH,
+        connection_string: str | None = None,
         **extra_engine_kwargs: Any,
     ) -> None:
         """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
         important args, and if incorrectly specified, we have run into hitting the pool
-        limit / using too many connections and overwhelming the database."""
+        limit / using too many connections and overwhelming the database.
+
+        Specifying connection_string directly will cause some of the other parameters
+        to be ignored.
+        """
         with cls._lock:
             if cls._engine:
                 return
 
-            connection_string = build_connection_string(
-                db_api=SYNC_DB_API,
-                app_name=cls._app_name + "_sync",
-                use_iam=USE_IAM_AUTH,
-            )
+            if not connection_string:
+                connection_string = build_connection_string(
+                    db_api=db_api,
+                    app_name=cls._app_name + "_sync",
+                    use_iam_auth=use_iam,
+                )
 
             # Start with base kwargs that are valid for all pool types
             final_engine_kwargs: dict[str, Any] = {}
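With the added parameters, callers can hand SqlEngine.init_engine a prebuilt connection string instead of relying on the env-derived one (this is what the new debugging script does). A minimal sketch, assuming conn_str was built with build_connection_string as in the earlier example:

    # Sketch: initialize the process-wide engine from an explicit connection string.
    # pool_size and max_overflow must still be passed; connection_string bypasses
    # the internal build_connection_string() call.
    SqlEngine.init_engine(
        pool_size=20,
        max_overflow=5,
        connection_string=conn_str,
    )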
@@ -274,7 +285,7 @@ class SqlEngine:
             # echo=True here for inspecting all emitted db queries
             engine = create_engine(connection_string, **final_engine_kwargs)
 
-            if USE_IAM_AUTH:
+            if use_iam:
                 event.listen(engine, "do_connect", provide_iam_token)
 
             cls._engine = engine
@@ -306,6 +317,8 @@ class SqlEngine:
 def get_all_tenant_ids() -> list[str]:
     """Returning [None] means the only tenant is the 'public' or self hosted tenant."""
 
+    tenant_ids: list[str]
+
     if not MULTI_TENANT:
         return [POSTGRES_DEFAULT_SCHEMA]
 
@@ -354,7 +367,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         app_name = SqlEngine.get_app_name() + "_async"
         connection_string = build_connection_string(
             db_api=ASYNC_DB_API,
-            use_iam=USE_IAM_AUTH,
+            use_iam_auth=USE_IAM_AUTH,
         )
 
         connect_args: dict[str, Any] = {}
backend/scripts/debugging/onyx_db.py (new file)
@@ -0,0 +1,149 @@
+"""Onyx Database tool"""
+
+import os
+
+# hack to work around excessive use of globals in other functions
+os.environ["MULTI_TENANT"] = "True"
+
+if True:  # noqa: E402
+    import csv
+    import argparse
+
+    from pydantic import BaseModel
+    from sqlalchemy import func
+
+    from onyx.db.engine import (
+        SYNC_DB_API,
+        USE_IAM_AUTH,
+        build_connection_string,
+        get_all_tenant_ids,
+    )
+    from onyx.db.engine import get_session_with_tenant
+    from onyx.db.engine import SqlEngine
+    from onyx.db.models import Document
+    from onyx.utils.logger import setup_logger
+    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+
+    import heapq
+
+logger = setup_logger()
+
+
+class TenantMetadata(BaseModel):
+    num_docs: int
+    num_chunks: int
+
+
+class SQLAlchemyDebugging:
+    # Class for managing DB debugging actions.
+    def __init__(self) -> None:
+        pass
+
+    def top_chunks(self, k: int = 10) -> None:
+        tenants_to_total_chunks: dict[str, TenantMetadata] = {}
+
+        logger.info("Fetching all tenant id's.")
+        tenant_ids = get_all_tenant_ids()
+        num_tenant_ids = len(tenant_ids)
+
+        logger.info(f"Found {num_tenant_ids} tenant id's.")
+
+        num_processed = 0
+        for tenant_id in tenant_ids:
+            num_processed += 1
+
+            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+            try:
+                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+                    # Calculate the total number of document rows for the current tenant
+                    total_documents = db_session.query(Document).count()
+                    # marginally useful to skip some tenants ... maybe we can improve on this
+                    # if total_documents < 100:
+                    #     logger.info(f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                    #                 f"docs={total_documents} skip=True")
+                    #     continue
+
+                    # Calculate the sum of chunk_count for the current tenant
+                    # If there are no documents or all chunk_counts are NULL, sum will be None
+                    total_chunks = db_session.query(
+                        func.sum(Document.chunk_count)
+                    ).scalar()
+                    total_chunks = total_chunks or 0
+
+                    logger.info(
+                        f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
+                        f"docs={total_documents} chunks={total_chunks}"
+                    )
+
+                    tenants_to_total_chunks[tenant_id] = TenantMetadata(
+                        num_docs=total_documents, num_chunks=total_chunks
+                    )
+            except Exception as e:
+                logger.error(f"Error processing tenant '{tenant_id}': {e}")
+            finally:
+                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+
+        # sort all by docs and dump to csv
+        sorted_tenants = sorted(
+            tenants_to_total_chunks.items(),
+            key=lambda x: (x[1].num_chunks, x[1].num_docs),
+            reverse=True,
+        )
+
+        csv_filename = "tenants_by_num_docs.csv"
+        with open(csv_filename, "w") as csvfile:
+            writer = csv.writer(csvfile)
+            writer.writerow(["tenant_id", "num_docs", "num_chunks"])  # Write header
+            # Write data rows (using the sorted list)
+            for tenant_id, metadata in sorted_tenants:
+                writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])
+        logger.info(f"Successfully wrote statistics to {csv_filename}")
+
+        # output top k by chunks
+        top_k_tenants = heapq.nlargest(
+            k, tenants_to_total_chunks.items(), key=lambda x: x[1].num_docs
+        )
+
+        logger.info(f"Top {k} tenants by total chunks: {top_k_tenants}")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Database/SQL debugging tool")
+    parser.add_argument("--username", help="Database username", default="postgres")
+    parser.add_argument("--password", help="Database password", required=True)
+    parser.add_argument("--host", help="Database host", default="localhost")
+    parser.add_argument("--port", help="Database port", default=5432)
+    parser.add_argument("--db", help="Database default db name", default="danswer")
+
+    parser.add_argument("--report", help="Generate the given report")
+
+    args = parser.parse_args()
+
+    logger.info(f"{args}")
+
+    connection_string = build_connection_string(
+        db_api=SYNC_DB_API,
+        app_name="onyx_db_sync",
+        use_iam_auth=USE_IAM_AUTH,
+        user=args.username,
+        password=args.password,
+        host=args.host,
+        port=args.port,
+        db=args.db,
+    )
+
+    SqlEngine.init_engine(
+        pool_size=20, max_overflow=5, connection_string=connection_string
+    )
+
+    debugger = SQLAlchemyDebugging()
+
+    if args.report == "top-chunks":
+        debugger.top_chunks(10)
+    else:
+        logger.info("No action.")
+
+
+if __name__ == "__main__":
+    main()
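Usage note: the script is driven entirely by the argparse flags above. A typical invocation might look like the line below, assuming it is run from the backend/ directory with the Onyx dependencies installed and importable (the exact entrypoint command is not part of this commit):

    python scripts/debugging/onyx_db.py --password <db-password> --host <db-host> --report top-chunks

The top-chunks report writes per-tenant document and chunk counts to tenants_by_num_docs.csv and logs the top 10 tenants by chunk count.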