danswer/backend/scripts/save_load_state.py
# This file is purely for development use, not included in any builds
# Remember to first send over the schema information (run the API Server)
import argparse
import json
import os
import subprocess

import requests
from alembic import command
from alembic.config import Config

from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.utils.logger import setup_logger

logger = setup_logger()


def save_postgres(filename: str, container_name: str) -> None:
    logger.notice("Attempting to take Postgres snapshot")
    cmd = (
        f"docker exec {container_name} pg_dump -U {POSTGRES_USER} "
        f"-h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB}"
    )
    with open(filename, "w") as file:
        # -W makes pg_dump prompt for a password, which is supplied via stdin
        subprocess.run(
            cmd,
            shell=True,
            check=True,
            stdout=file,
            text=True,
            input=f"{POSTGRES_PASSWORD}\n",
        )
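
# Optional sanity check for the snapshot (assumes pg_restore is installed on
# the host): `pg_restore -l <checkpoint_dir>/postgres_snapshot.tar` lists the
# archive's table of contents without restoring anything.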


def load_postgres(filename: str, container_name: str) -> None:
    logger.notice("Attempting to load Postgres snapshot")
    # Bring the schema up to date before restoring the data
    try:
        alembic_cfg = Config("alembic.ini")
        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        logger.error(f"Alembic upgrade failed: {e}")

    # Copy the dump into the container, then restore it there
    host_file_path = os.path.abspath(filename)
    copy_cmd = f"docker cp {host_file_path} {container_name}:/tmp/"
    subprocess.run(copy_cmd, shell=True, check=True)

    container_file_path = f"/tmp/{os.path.basename(filename)}"
    # --clean drops existing objects before recreating them; -1 runs the
    # restore as a single transaction
    restore_cmd = (
        f"docker exec {container_name} pg_restore --clean -U {POSTGRES_USER} "
        f"-h localhost -p {POSTGRES_PORT} -d {POSTGRES_DB} -1 -F t {container_file_path}"
    )
    subprocess.run(restore_cmd, shell=True, check=True)


def save_vespa(filename: str) -> None:
    logger.notice("Attempting to take Vespa snapshot")
    continuation = ""
    params = {}
    doc_jsons: list[dict] = []
    # Page through Vespa's visit API until no continuation token is returned
    while continuation is not None:
        if continuation:
            params = {"continuation": continuation}
        response = requests.get(DOCUMENT_ID_ENDPOINT, params=params)
        response.raise_for_status()
        found = response.json()
        continuation = found.get("continuation")
        docs = found["documents"]
        for doc in docs:
            doc_json = {"update": doc["id"], "create": True, "fields": doc["fields"]}
            doc_jsons.append(doc_json)

    # Write one document per line (JSONL) so loading can stream it back
    with open(filename, "w") as jsonl_file:
        for doc in doc_jsons:
            json_str = json.dumps(doc)
            jsonl_file.write(json_str + "\n")
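
# Each line of the snapshot is the partial-update payload that load_vespa
# replays; shape for illustration (the id follows Vespa's
# "id:<namespace>:<doctype>::<doc-id>" convention):
#   {"update": "id:...::<doc-id>", "create": true, "fields": {...}}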


def load_vespa(filename: str) -> None:
    headers = {"Content-Type": "application/json"}
    with open(filename, "r") as f:
        for line in f:
            new_doc = json.loads(line.strip())
            # The document id is the last component of the full Vespa id string
            doc_id = new_doc["update"].split("::")[-1]
            response = requests.post(
                DOCUMENT_ID_ENDPOINT + "/" + doc_id,
                headers=headers,
                json=new_doc,
            )
            response.raise_for_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Danswer checkpoint saving and loading."
    )
    parser.add_argument(
        "--save", action="store_true", help="Save Danswer state to directory."
    )
    parser.add_argument(
        "--load", action="store_true", help="Load Danswer state from save directory."
    )
    parser.add_argument(
        "--postgres_container_name",
        type=str,
        default="danswer-stack-relational_db-1",
        help="Name of the postgres container to dump",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=os.path.join("..", "danswer_checkpoint"),
        help="A directory to store temporary files in.",
    )
    args = parser.parse_args()

    checkpoint_dir = args.checkpoint_dir
    postgres_container = args.postgres_container_name

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if not args.save and not args.load:
        raise ValueError("Must specify --save or --load")

    if args.load:
        load_postgres(
            os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
        )
        load_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
    else:
        save_postgres(
            os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
        )
        save_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
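
# Example invocations, run from the backend directory so that "alembic.ini"
# resolves (container name and paths are the defaults above):
#   python scripts/save_load_state.py --save
#   python scripts/save_load_state.py --load --checkpoint_dir ../danswer_checkpoint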