From 2fb2c40851c6e93872889701d36efffe8e936942 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 24 Jun 2023 00:00:06 -0700 Subject: [PATCH] DAN-145 Danswer save state (#115) --- backend/scripts/save_load_state.py | 160 +++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 backend/scripts/save_load_state.py diff --git a/backend/scripts/save_load_state.py b/backend/scripts/save_load_state.py new file mode 100644 index 000000000..de75e78c7 --- /dev/null +++ b/backend/scripts/save_load_state.py @@ -0,0 +1,160 @@ +# This file is purely for development use, not included in any builds +import argparse +import json +import os +import subprocess +from datetime import datetime + +import requests +from alembic import command +from alembic.config import Config +from danswer.configs.app_configs import POSTGRES_DB +from danswer.configs.app_configs import POSTGRES_HOST +from danswer.configs.app_configs import POSTGRES_PASSWORD +from danswer.configs.app_configs import POSTGRES_PORT +from danswer.configs.app_configs import POSTGRES_USER +from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION +from danswer.configs.app_configs import QDRANT_HOST +from danswer.configs.app_configs import QDRANT_PORT +from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION +from danswer.datastores.qdrant.indexing import create_qdrant_collection +from danswer.datastores.qdrant.indexing import list_qdrant_collections +from danswer.datastores.typesense.store import create_typesense_collection +from danswer.utils.clients import get_qdrant_client +from danswer.utils.clients import get_typesense_client +from danswer.utils.logging import setup_logger +from qdrant_client.http.models.models import SnapshotDescription +from typesense.exceptions import ObjectNotFound # type: ignore + +logger = setup_logger() + + +def save_postgres(filename: str) -> None: + logger.info("Attempting to take Postgres snapshot") + cmd = f"pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB} > {filename}" + subprocess.run( + cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True + ) + + +def load_postgres(filename: str) -> None: + logger.info("Attempting to load Postgres snapshot") + try: + alembic_cfg = Config("alembic.ini") + command.upgrade(alembic_cfg, "head") + except Exception as e: + logger.info("Alembic upgrade failed, maybe already has run") + cmd = f"pg_restore --clean -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -d {POSTGRES_DB} -1 {filename}" + subprocess.run( + cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True + ) + + +def snapshot_time_compare(snap: SnapshotDescription) -> datetime: + if not hasattr(snap, "creation_time") or snap.creation_time is None: + raise RuntimeError("Qdrant Snapshots Failed") + return datetime.strptime(snap.creation_time, "%Y-%m-%dT%H:%M:%S") + + +def save_qdrant(filename: str) -> None: + logger.info("Attempting to take Qdrant snapshot") + qdrant_client = get_qdrant_client() + qdrant_client.create_snapshot(collection_name=QDRANT_DEFAULT_COLLECTION) + snapshots = qdrant_client.list_snapshots(collection_name=QDRANT_DEFAULT_COLLECTION) + valid_snapshots = [snap for snap in snapshots if snap.creation_time is not None] + + sorted_snapshots = sorted(valid_snapshots, key=snapshot_time_compare) + last_snapshot_name = sorted_snapshots[-1].name + url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/{last_snapshot_name}" + + response = requests.get(url, stream=True) + + if response.status_code != 200: + raise RuntimeError("Qdrant Save Failed") + + with open(filename, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + + +def load_qdrant(filename: str) -> None: + logger.info("Attempting to load Qdrant snapshot") + if QDRANT_DEFAULT_COLLECTION not in { + collection.name for collection in list_qdrant_collections().collections + }: + create_qdrant_collection(QDRANT_DEFAULT_COLLECTION) + snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/" + + with open(filename, "rb") as f: + files = {"snapshot": (os.path.basename(filename), f)} + response = requests.post(snapshot_url + "upload", files=files) + if response.status_code != 200: + raise RuntimeError("Qdrant Snapshot Upload Failed") + + data = {"location": snapshot_url + os.path.basename(filename)} + headers = {"Content-Type": "application/json"} + response = requests.put( + snapshot_url + "recover", data=json.dumps(data), headers=headers + ) + if response.status_code != 200: + raise RuntimeError("Loading Qdrant Snapshot Failed") + + +def save_typesense(filename: str) -> None: + logger.info("Attempting to take Typesense snapshot") + ts_client = get_typesense_client() + all_docs = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.export() + with open(filename, "w") as f: + f.write(all_docs) + + +def load_typesense(filename: str) -> None: + logger.info("Attempting to load Typesense snapshot") + ts_client = get_typesense_client() + try: + ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].delete() + except ObjectNotFound: + pass + + create_typesense_collection(TYPESENSE_DEFAULT_COLLECTION) + + with open(filename) as jsonl_file: + ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.import_( + jsonl_file.read().encode("utf-8"), {"action": "create"} + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Danswer checkpoint saving and loading." + ) + parser.add_argument( + "--save", action="store_true", help="Save Danswer state to directory." + ) + parser.add_argument( + "--load", action="store_true", help="Load Danswer state from save directory." + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=os.path.join("..", "danswer_checkpoint"), + help="A directory to store temporary files to.", + ) + + args = parser.parse_args() + checkpoint_dir = args.checkpoint_dir + + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + if not args.save and not args.load: + raise ValueError("Must specify --save or --load") + + if args.load: + load_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar")) + load_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot")) + load_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl")) + else: + save_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar")) + save_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot")) + save_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))