mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-04 17:00:24 +02:00
* Fail instead of continuing if vespa cannot be reached within the timeout period * improve startup readability --------- Co-authored-by: Richard Kuo <rkuo@rkuo.com>
173 lines
5.2 KiB
Python
173 lines
5.2 KiB
Python
import logging
|
|
import time
|
|
|
|
import psycopg2
|
|
import requests
|
|
|
|
from alembic import command
|
|
from alembic.config import Config
|
|
from danswer.configs.app_configs import POSTGRES_HOST
|
|
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
|
from danswer.configs.app_configs import POSTGRES_PORT
|
|
from danswer.configs.app_configs import POSTGRES_USER
|
|
from danswer.db.engine import build_connection_string
|
|
from danswer.db.engine import get_session_context_manager
|
|
from danswer.db.engine import SYNC_DB_API
|
|
from danswer.db.search_settings import get_current_search_settings
|
|
from danswer.db.swap_index import check_index_swap
|
|
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
|
from danswer.document_index.vespa.index import VespaIndex
|
|
from danswer.indexing.models import IndexingSetting
|
|
from danswer.main import setup_postgres
|
|
from danswer.main import setup_vespa
|
|
|
|
|
|
def _run_migrations(
    database_url: str, direction: str = "upgrade", revision: str = "head"
) -> None:
    """Run Alembic migrations against the given database.

    Args:
        database_url: SQLAlchemy URL written to ``sqlalchemy.url`` in the
            Alembic configuration loaded from ``alembic.ini``.
        direction: ``"upgrade"`` or ``"downgrade"``.
        revision: Target revision handed to Alembic (e.g. ``"head"``, ``"base"``).

    Raises:
        ValueError: If ``direction`` is neither ``"upgrade"`` nor ``"downgrade"``.
    """
    # Validate up front so that a bad argument has no side effects on logging.
    if direction not in ("upgrade", "downgrade"):
        raise ValueError(
            f"Invalid direction: {direction}. Must be 'upgrade' or 'downgrade'."
        )

    alembic_logger = logging.getLogger("alembic")

    # hide info logs emitted during migration
    alembic_logger.setLevel(logging.CRITICAL)
    try:
        # Create an Alembic configuration object
        alembic_cfg = Config("alembic.ini")
        alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
        # NOTE(review): presumably read by the project's alembic env.py to skip
        # its own logger configuration -- confirm against env.py.
        alembic_cfg.attributes["configure_logger"] = False

        # Set the SQLAlchemy URL in the Alembic configuration
        alembic_cfg.set_main_option("sqlalchemy.url", database_url)

        # Run the migration
        if direction == "upgrade":
            command.upgrade(alembic_cfg, revision)
        else:
            command.downgrade(alembic_cfg, revision)
    finally:
        # Restore the logger even when the migration raises, so a failed run
        # does not leave alembic logging muted for the rest of the process.
        alembic_logger.setLevel(logging.INFO)
|
|
|
|
|
|
def reset_postgres(database: str = "postgres") -> None:
    """Reset the Postgres database.

    Deletes all rows from every table in the ``public`` schema (except the
    Alembic version table and the Kombu tables), then downgrades the schema to
    ``base``, upgrades it back to ``head`` and re-runs the Postgres setup that
    is performed on API server startup.

    Args:
        database: Name of the database to reset.
    """
    # Tables whose contents must survive the wipe: migration history and Kombu.
    preserved_tables = {"alembic_version", "kombu_message", "kombu_queue"}

    # NOTE: need to delete all rows to allow migrations to be rolled back
    # as there are a few downgrades that don't properly handle data in tables
    conn = psycopg2.connect(
        dbname=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
    )
    # try/finally guarantees the cursor and connection are released even if
    # one of the statements below fails.
    try:
        cur = conn.cursor()
        try:
            # Disable triggers to prevent foreign key constraints from being checked
            cur.execute("SET session_replication_role = 'replica';")

            # Fetch all table names in the current database
            cur.execute(
                """
                SELECT tablename
                FROM pg_tables
                WHERE schemaname = 'public'
                """
            )

            for (table_name, *_rest) in cur.fetchall():
                if table_name in preserved_tables:
                    continue
                cur.execute(f'DELETE FROM "{table_name}"')

            # Re-enable triggers
            cur.execute("SET session_replication_role = 'origin';")

            conn.commit()
        finally:
            cur.close()
    finally:
        conn.close()

    # downgrade to base + upgrade back to head
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
    )
    _run_migrations(
        conn_str,
        direction="downgrade",
        revision="base",
    )
    _run_migrations(
        conn_str,
        direction="upgrade",
        revision="head",
    )

    # do the same thing as we do on API server startup
    with get_session_context_manager() as db_session:
        setup_postgres(db_session)
|
|
|
|
|
|
def reset_vespa() -> None:
    """Wipe all data from the Vespa index.

    Ensures the index for the current search settings is set up, then deletes
    every document from it, following continuation tokens until none remain.
    The whole deletion is retried a limited number of times on error.

    Raises:
        RuntimeError: If Vespa cannot be reached within the timeout, or if the
            documents still could not be deleted after all retry attempts.
    """
    max_attempts = 5
    retry_delay_seconds = 5

    with get_session_context_manager() as db_session:
        # swap to the correct default model
        check_index_swap(db_session)

        search_settings = get_current_search_settings(db_session)
        index_name = search_settings.index_name
        # Build the plain settings object while the session is still open so
        # nothing below depends on the DB model after the session is closed.
        index_setting = IndexingSetting.from_db_model(search_settings)

    success = setup_vespa(
        document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
        index_setting=index_setting,
        secondary_index_setting=None,
    )
    if not success:
        raise RuntimeError("Could not connect to Vespa within the specified timeout.")

    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            # Each attempt starts over from the beginning of the index.
            continuation = None
            while True:
                params = {"selection": "true", "cluster": "danswer_index"}
                if continuation:
                    params["continuation"] = continuation
                response = requests.delete(
                    DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
                )
                response.raise_for_status()

                # A continuation token in the response means more documents
                # remain to be deleted.
                continuation = response.json().get("continuation")
                if not continuation:
                    break
            return
        except Exception as e:
            last_error = e
            print(f"Error deleting documents: {e}")
            # No point in waiting after the final attempt.
            if attempt < max_attempts:
                time.sleep(retry_delay_seconds)

    # Fail loudly instead of returning as if the index had been wiped.
    raise RuntimeError(
        f"Could not delete documents from Vespa after {max_attempts} attempts."
    ) from last_error
|
|
|
|
|
|
def reset_all() -> None:
    """Reset Postgres first, then Vespa, printing progress along the way."""
    # Postgres is reset before Vespa, in that order.
    stages = (
        ("Resetting Postgres...", reset_postgres),
        ("Resetting Vespa...", reset_vespa),
    )
    for announcement, run_stage in stages:
        print(announcement)
        run_stage()
    print("Finished resetting all.")
|