# This file is purely for development use, not included in any builds
import argparse
import json
import os
import subprocess
from datetime import datetime

import requests
from alembic import command
from alembic.config import Config
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.datastores.qdrant.indexing import create_qdrant_collection
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
from danswer.utils.logging import setup_logger
from qdrant_client.http.models.models import SnapshotDescription
from typesense.exceptions import ObjectNotFound  # type: ignore

logger = setup_logger()


def save_postgres(filename: str) -> None:
    logger.info("Attempting to take Postgres snapshot")
    cmd = f"pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB} > {filename}"
    # -W makes pg_dump prompt for a password; it is supplied on stdin
    subprocess.run(
        cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
    )


def load_postgres(filename: str) -> None:
    logger.info("Attempting to load Postgres snapshot")
    # bring the schema up to date before restoring data
    try:
        alembic_cfg = Config("alembic.ini")
        command.upgrade(alembic_cfg, "head")
    except Exception:
        logger.info("Alembic upgrade failed, it may have already been run")

    cmd = f"pg_restore --clean -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -d {POSTGRES_DB} -1 {filename}"
    subprocess.run(
        cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
    )


def snapshot_time_compare(snap: SnapshotDescription) -> datetime:
    # sort key: parse the snapshot's creation timestamp
    if not hasattr(snap, "creation_time") or snap.creation_time is None:
        raise RuntimeError("Qdrant Snapshots Failed")
    return datetime.strptime(snap.creation_time, "%Y-%m-%dT%H:%M:%S")


def save_qdrant(filename: str) -> None:
    logger.info("Attempting to take Qdrant snapshot")
    qdrant_client = get_qdrant_client()
    # trigger a new snapshot, then pick the most recent one by creation time
    qdrant_client.create_snapshot(collection_name=QDRANT_DEFAULT_COLLECTION)
    snapshots = qdrant_client.list_snapshots(collection_name=QDRANT_DEFAULT_COLLECTION)
    valid_snapshots = [snap for snap in snapshots if snap.creation_time is not None]

    sorted_snapshots = sorted(valid_snapshots, key=snapshot_time_compare)
    last_snapshot_name = sorted_snapshots[-1].name

    # download the snapshot file via Qdrant's REST API
    url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/{last_snapshot_name}"
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise RuntimeError("Qdrant Save Failed")

    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)


def load_qdrant(filename: str) -> None:
    logger.info("Attempting to load Qdrant snapshot")
    # create the collection if it does not exist yet
    if QDRANT_DEFAULT_COLLECTION not in {
        collection.name for collection in list_qdrant_collections().collections
    }:
        create_qdrant_collection(QDRANT_DEFAULT_COLLECTION)
    snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/"

    # upload the snapshot file, then ask Qdrant to recover from it
    with open(filename, "rb") as f:
        files = {"snapshot": (os.path.basename(filename), f)}
        response = requests.post(snapshot_url + "upload", files=files)
        if response.status_code != 200:
            raise RuntimeError("Qdrant Snapshot Upload Failed")

    data = {"location": snapshot_url + os.path.basename(filename)}
    headers = {"Content-Type": "application/json"}
    response = requests.put(
        snapshot_url + "recover", data=json.dumps(data), headers=headers
    )
    if response.status_code != 200:
        raise RuntimeError("Loading Qdrant Snapshot Failed")


def save_typesense(filename: str) -> None:
    logger.info("Attempting to take Typesense snapshot")
    ts_client = get_typesense_client()
    all_docs = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.export()
    with open(filename, "w") as f:
        f.write(all_docs)


def load_typesense(filename: str) -> None:
    logger.info("Attempting to load Typesense snapshot")
    ts_client = get_typesense_client()
    try:
        ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].delete()
    except ObjectNotFound:
        pass

    create_typesense_collection(TYPESENSE_DEFAULT_COLLECTION)

    with open(filename) as jsonl_file:
        ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.import_(
            jsonl_file.read().encode("utf-8"), {"action": "create"}
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Danswer checkpoint saving and loading."
    )
    parser.add_argument(
        "--save", action="store_true", help="Save Danswer state to directory."
    )
    parser.add_argument(
        "--load", action="store_true", help="Load Danswer state from save directory."
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=os.path.join("..", "danswer_checkpoint"),
        help="A directory to store temporary files to.",
    )

    args = parser.parse_args()
    checkpoint_dir = args.checkpoint_dir

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if not args.save and not args.load:
        raise ValueError("Must specify --save or --load")

    if args.load:
        load_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
        load_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
        load_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))
    else:
        save_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
        save_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
        save_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))
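
# Example invocations (illustrative; the script name and working directory are
# assumptions -- run from wherever alembic.ini and the danswer package resolve):
#   python save_load_state.py --save --checkpoint_dir ../danswer_checkpoint
#   python save_load_state.py --load --checkpoint_dir ../danswer_checkpoint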