Vespa Save and Load (#422)

This commit is contained in:
Yuhong Sun 2023-09-09 20:25:31 -07:00 committed by GitHub
parent 0e65688166
commit 4a0c2bf866
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 5 deletions

View File

@ -34,7 +34,7 @@ logger = setup_logger()
class QAHandler(abc.ABC):
"""Evolution of the `PromptProcessor` - handles both building the prompt and
processing the response. These are neccessarily coupled, since the prompt determines
processing the response. These are necessarily coupled, since the prompt determines
the response format (and thus how it should be parsed into an answer + quotes)."""
@abc.abstractmethod

View File

@ -22,6 +22,7 @@ from danswer.configs.app_configs import QDRANT_PORT
from danswer.datastores.qdrant.utils import create_qdrant_collection
from danswer.datastores.qdrant.utils import list_qdrant_collections
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.datastores.vespa.store import DOCUMENT_ID_ENDPOINT
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
from danswer.utils.logger import setup_logger
@ -124,6 +125,41 @@ def load_typesense(filename: str) -> None:
)
def save_vespa(filename: str) -> None:
    """Dump every document returned by DOCUMENT_ID_ENDPOINT to a JSONL file.

    The endpoint is fetched page by page: each response's "continuation"
    value is sent back as the `continuation` query parameter of the next
    request, and paging ends once a response carries no such key. Every
    returned document is converted into an
    {"update": <id>, "create": True, "fields": <fields>} record, and the
    records are written to `filename`, one JSON object per line.

    Raises whatever `response.raise_for_status()` raises on an HTTP error
    status; nothing is written to `filename` in that case because the file
    is only opened after all pages were fetched.
    """
    logger.info("Attempting to take Vespa snapshot")

    snapshot_rows: list[dict] = []
    query_params = {}
    token = ""
    while True:
        # The very first request goes out without a continuation token.
        if token:
            query_params = {"continuation": token}

        response = requests.get(DOCUMENT_ID_ENDPOINT, params=query_params)
        response.raise_for_status()
        page = response.json()

        token = page.get("continuation")
        snapshot_rows += [
            {"update": entry["id"], "create": True, "fields": entry["fields"]}
            for entry in page["documents"]
        ]

        # A missing "continuation" key marks the last page.
        if token is None:
            break

    with open(filename, "w") as jsonl_file:
        jsonl_file.writelines(json.dumps(row) + "\n" for row in snapshot_rows)
def load_vespa(filename: str) -> None:
    """Replay a JSONL snapshot by POSTing each record to DOCUMENT_ID_ENDPOINT.

    Each non-blank line of `filename` must be a JSON object with an "update"
    key holding the full document id. The part of that id following the
    first "::" is appended (percent-encoded) to DOCUMENT_ID_ENDPOINT, and the
    whole record is sent as the JSON request body.

    Blank lines (for example a trailing empty line left by an editor) are
    skipped instead of aborting the load with a JSON decode error.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a record has no "update" key.
        Whatever `response.raise_for_status()` raises on an HTTP error
        status; loading stops at the first failing record.
    """
    # Local import so this block does not depend on the file's import section.
    from urllib.parse import quote

    headers = {"Content-Type": "application/json"}
    with open(filename, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            new_doc = json.loads(line)
            # Split only on the first "::" so that a user-specified id which
            # itself contains "::" is kept intact; an id without "::" is used
            # unchanged, as before.
            doc_id = new_doc["update"].split("::", 1)[-1]
            # Percent-encode the id so characters such as "/", "?" or "#"
            # cannot change which URL path the request is sent to.
            response = requests.post(
                DOCUMENT_ID_ENDPOINT + "/" + quote(doc_id, safe=""),
                headers=headers,
                json=new_doc,
            )
            response.raise_for_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Danswer checkpoint saving and loading."
@ -152,9 +188,11 @@ if __name__ == "__main__":
if args.load:
load_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
load_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
load_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))
load_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
# load_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
# load_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))
else:
save_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
save_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
save_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))
save_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
# save_qdrant(os.path.join(checkpoint_dir, "qdrant.snapshot"))
# save_typesense(os.path.join(checkpoint_dir, "typesense_snapshot.jsonl"))