Search Regression Test and Save/Load State updates (#761)

This commit is contained in:
Yuhong Sun
2023-11-23 00:00:30 -08:00
committed by GitHub
parent fda377a2fa
commit 13001ede98
5 changed files with 138 additions and 46 deletions

View File

@ -1,4 +1,5 @@
# This file is purely for development use, not included in any builds
# Remember to first send over the schema information (run API Server)
import argparse
import json
import os
@ -19,25 +20,40 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def save_postgres(filename: str) -> None:
def save_postgres(filename: str, container_name: str) -> None:
logger.info("Attempting to take Postgres snapshot")
cmd = f"pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB} > {filename}"
subprocess.run(
cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
)
cmd = f"docker exec {container_name} pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB}"
with open(filename, "w") as file:
subprocess.run(
cmd,
shell=True,
check=True,
stdout=file,
text=True,
input=f"{POSTGRES_PASSWORD}\n",
)
def load_postgres(filename: str) -> None:
def load_postgres(filename: str, container_name: str) -> None:
logger.info("Attempting to load Postgres snapshot")
try:
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception:
logger.info("Alembic upgrade failed, maybe already has run")
cmd = f"pg_restore --clean -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -d {POSTGRES_DB} -1 {filename}"
subprocess.run(
cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
except Exception as e:
logger.error(f"Alembic upgrade failed: {e}")
host_file_path = os.path.abspath(filename)
copy_cmd = f"docker cp {host_file_path} {container_name}:/tmp/"
subprocess.run(copy_cmd, shell=True, check=True)
container_file_path = f"/tmp/{os.path.basename(filename)}"
restore_cmd = (
f"docker exec {container_name} pg_restore --clean -U {POSTGRES_USER} "
f"-h localhost -p {POSTGRES_PORT} -d {POSTGRES_DB} -1 -F t {container_file_path}"
)
subprocess.run(restore_cmd, shell=True, check=True)
def save_vespa(filename: str) -> None:
@ -85,6 +101,12 @@ if __name__ == "__main__":
parser.add_argument(
"--load", action="store_true", help="Load Danswer state from save directory."
)
parser.add_argument(
"--postgres_container_name",
type=str,
default="danswer-stack-relational_db-1",
help="Name of the postgres container to dump",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
@ -94,6 +116,7 @@ if __name__ == "__main__":
args = parser.parse_args()
checkpoint_dir = args.checkpoint_dir
postgres_container = args.postgres_container_name
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
@ -102,8 +125,12 @@ if __name__ == "__main__":
raise ValueError("Must specify --save or --load")
if args.load:
load_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
load_postgres(
os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
)
load_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
else:
save_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
save_postgres(
os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
)
save_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))