Catch dropped eval questions and add multiprocessing (#1849)

This commit is contained in:
hagen-danswer 2024-07-16 12:33:02 -07:00 committed by GitHub
parent 25b3dacaba
commit b83f435bb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 102 additions and 81 deletions

View File

@ -33,6 +33,7 @@ def translate_doc_response_to_simple_doc(
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,

View File

@ -44,6 +44,7 @@ class BasicCreateChatMessageRequest(ChunkContext):
class SimpleDoc(BaseModel):
id: str
semantic_identifier: str
link: str | None
blurb: str

View File

@ -2,23 +2,43 @@ import requests
from retry import retry
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingStatus
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import ConnectorBase
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
from tests.regression.answer_quality.cli_utils import restart_vespa_container
GENERAL_HEADERS = {"Content-Type": "application/json"}
def _api_url_builder(run_suffix: str, api_path: str) -> str:
    """Return the full URL for *api_path* on the API server of the given run."""
    base = f"http://localhost:{get_api_server_host_port(run_suffix)}"
    return base + api_path
def _create_new_chat_session(run_suffix: str) -> int:
    """Create a new chat session via the API and return its integer id.

    Raises:
        RuntimeError: if the response payload does not contain an integer
            ``chat_session_id`` (the full payload is included for debugging).
    """
    url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
    request = ChatSessionCreationRequest(
        persona_id=0,
        description=None,
    )
    payload = requests.post(url, headers=GENERAL_HEADERS, json=request.dict()).json()
    session_id = payload.get("chat_session_id")
    if not isinstance(session_id, int):
        raise RuntimeError(payload)
    return session_id
@retry(tries=15, delay=10, jitter=1)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
filters = IndexFilters(
@ -28,51 +48,43 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
tags=None,
access_control_list=None,
)
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
new_message_request = DirectQARequest(
messages=messages,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
chain_of_thought=False,
return_contexts=True,
retrieval_options = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
)
url = _api_url_builder(run_suffix, "/query/answer-with-quote/")
headers = {
"Content-Type": "application/json",
}
chat_session_id = _create_new_chat_session(run_suffix)
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
new_message_request = BasicCreateChatMessageRequest(
chat_session_id=chat_session_id,
message=query,
retrieval_options=retrieval_options,
query_override=query,
)
body = new_message_request.dict()
body["user"] = None
try:
response_json = requests.post(url, headers=headers, json=body).json()
context_data_list = response_json.get("contexts", {}).get("contexts", [])
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
simple_search_docs = response_json.get("simple_search_docs", [])
answer = response_json.get("answer", "")
except Exception as e:
print("Failed to answer the questions:")
print(f"\t {str(e)}")
print("Restarting vespa container and trying agian")
restart_vespa_container(run_suffix)
print("trying again")
raise e
return context_data_list, answer
return simple_search_docs, answer
def check_if_query_ready(run_suffix: str) -> bool:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
headers = {
"Content-Type": "application/json",
}
indexing_status_dict = requests.get(url, headers=headers).json()
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
ongoing_index_attempts = False
doc_count = 0
@ -94,17 +106,13 @@ def check_if_query_ready(run_suffix: str) -> bool:
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
headers = {
"Content-Type": "application/json",
}
body = {
"connector_id": connector_id,
"credential_ids": [credential_id],
"from_beginning": True,
}
print("body:", body)
response = requests.post(url, headers=headers, json=body)
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
print("Connector created successfully:", response.json())
else:
@ -116,13 +124,10 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
url = _api_url_builder(
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
)
headers = {
"Content-Type": "application/json",
}
body = {"name": "zip_folder_contents", "is_public": True}
print("body:", body)
response = requests.put(url, headers=headers, json=body)
response = requests.put(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
print("Connector created successfully:", response.json())
else:
@ -132,14 +137,12 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
def _get_existing_connector_names(run_suffix: str) -> list[str]:
url = _api_url_builder(run_suffix, "/manage/connector")
headers = {
"Content-Type": "application/json",
}
body = {
"credential_json": {},
"admin_public": True,
}
response = requests.get(url, headers=headers, json=body)
response = requests.get(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
connectors = response.json()
return [connector["name"] for connector in connectors]
@ -149,9 +152,6 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
url = _api_url_builder(run_suffix, "/manage/admin/connector")
headers = {
"Content-Type": "application/json",
}
connector_name = base_connector_name = "search_eval_connector"
existing_connector_names = _get_existing_connector_names(run_suffix)
@ -172,7 +172,7 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
body = connector.dict()
print("body:", body)
response = requests.post(url, headers=headers, json=body)
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
print("Connector created successfully:", response.json())
return response.json()["id"]
@ -182,14 +182,11 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
def create_credential(run_suffix: str) -> int:
url = _api_url_builder(run_suffix, "/manage/credential")
headers = {
"Content-Type": "application/json",
}
body = {
"credential_json": {},
"admin_public": True,
}
response = requests.post(url, headers=headers, json=body)
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
print("credential created successfully:", response.json())
return response.json()["id"]

View File

@ -1,4 +1,5 @@
import json
import multiprocessing
import os
import time
@ -13,11 +14,12 @@ RESULTS_FILENAME = "results.jsonl"
METADATA_FILENAME = "metadata.yaml"
def _update_results_file(output_folder_path: str, qa_output: dict) -> None:
def _populate_results_file(output_folder_path: str, all_qa_output: list[dict]) -> None:
output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME)
with open(output_file_path, "w", encoding="utf-8") as file:
file.write(json.dumps(qa_output) + "\n")
file.flush()
with open(output_file_path, "a", encoding="utf-8") as file:
for qa_output in all_qa_output:
file.write(json.dumps(qa_output) + "\n")
file.flush()
def _update_metadata_file(test_output_folder: str, count: int) -> None:
@ -81,8 +83,8 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]:
del env_vars["GPG_KEY"]
if metadata["config"]["llm"]["api_key"]:
del metadata["config"]["llm"]["api_key"]
if metadata["test_config"]["llm"]["api_key"]:
del metadata["test_config"]["llm"]["api_key"]
metadata.update(env_vars)
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
print("saving metadata to:", metadata_path)
@ -92,7 +94,34 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
return test_output_folder, questions
def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
    """Ask one eval question against the running API and bundle the result.

    Returns a dict containing the original question data, the answer string,
    and the list of context documents that were retrieved for it.
    """
    print(f"On question number {question_number}")

    query = question_data["question"]
    print(f"query: {query}")

    context_data_list, answer = get_answer_from_query(
        query=query, run_suffix=config["run_suffix"]
    )

    if context_data_list:
        print(f"answer: {answer[:50]}...")
        print(f"{len(context_data_list)} context docs found")
    else:
        print("No answer or context found")
    print("\n")

    return {
        "question_data": question_data,
        "answer": answer,
        "context_data_list": context_data_list,
    }
def _process_and_write_query_results(config: dict) -> None:
start_time = time.time()
test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder)
@ -101,33 +130,26 @@ def _process_and_write_query_results(config: dict) -> None:
if config["limit"] is not None:
questions = questions[: config["limit"]]
count = 1
for question_data in questions:
print(f"On question number {count}")
query = question_data["question"]
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
results = pool.starmap(
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
)
if not context_data_list:
print("No answer or context found")
else:
print(f"answer: {answer[:50]}...")
print(f"{len(context_data_list)} context docs found")
print("\n")
_populate_results_file(test_output_folder, results)
output = {
"question_data": question_data,
"answer": answer,
"context_data_list": context_data_list,
}
valid_answer_count = 0
for result in results:
if result.get("answer"):
valid_answer_count += 1
_update_results_file(test_output_folder, output)
_update_metadata_file(test_output_folder, count)
count += 1
_update_metadata_file(test_output_folder, valid_answer_count)
time_to_finish = time.time() - start_time
minutes, seconds = divmod(int(time_to_finish), 60)
print(
f"Took {minutes:02d}:{seconds:02d} to ask and answer {len(results)} questions"
)
def run_qa_test_and_save_results(run_suffix: str = "") -> None: