Catch dropped eval questions and add multiprocessing (#1849)

This commit is contained in:
hagen-danswer
2024-07-16 12:33:02 -07:00
committed by GitHub
parent 25b3dacaba
commit b83f435bb0
4 changed files with 102 additions and 81 deletions

View File

@@ -33,6 +33,7 @@ def translate_doc_response_to_simple_doc(
) -> list[SimpleDoc]: ) -> list[SimpleDoc]:
return [ return [
SimpleDoc( SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier, semantic_identifier=doc.semantic_identifier,
link=doc.link, link=doc.link,
blurb=doc.blurb, blurb=doc.blurb,

View File

@@ -44,6 +44,7 @@ class BasicCreateChatMessageRequest(ChunkContext):
class SimpleDoc(BaseModel): class SimpleDoc(BaseModel):
id: str
semantic_identifier: str semantic_identifier: str
link: str | None link: str | None
blurb: str blurb: str

View File

@@ -2,23 +2,43 @@ import requests
from retry import retry from retry import retry
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.connectors.models import InputType from danswer.connectors.models import InputType
from danswer.db.enums import IndexingStatus from danswer.db.enums import IndexingStatus
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import ConnectorBase from danswer.server.documents.models import ConnectorBase
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from tests.regression.answer_quality.cli_utils import get_api_server_host_port from tests.regression.answer_quality.cli_utils import get_api_server_host_port
from tests.regression.answer_quality.cli_utils import restart_vespa_container
GENERAL_HEADERS = {"Content-Type": "application/json"}
def _api_url_builder(run_suffix: str, api_path: str) -> str: def _api_url_builder(run_suffix: str, api_path: str) -> str:
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
def _create_new_chat_session(run_suffix: str) -> int:
create_chat_request = ChatSessionCreationRequest(
persona_id=0,
description=None,
)
body = create_chat_request.dict()
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
response_json = requests.post(
create_chat_url, headers=GENERAL_HEADERS, json=body
).json()
chat_session_id = response_json.get("chat_session_id")
if isinstance(chat_session_id, int):
return chat_session_id
else:
raise RuntimeError(response_json)
@retry(tries=15, delay=10, jitter=1) @retry(tries=15, delay=10, jitter=1)
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]: def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
filters = IndexFilters( filters = IndexFilters(
@@ -28,51 +48,43 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
tags=None, tags=None,
access_control_list=None, access_control_list=None,
) )
retrieval_options = RetrievalDetails(
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)] run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
new_message_request = DirectQARequest( filters=filters,
messages=messages, enable_auto_detect_filters=False,
prompt_id=0,
persona_id=0,
retrieval_options=RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=True,
filters=filters,
enable_auto_detect_filters=False,
),
chain_of_thought=False,
return_contexts=True,
) )
url = _api_url_builder(run_suffix, "/query/answer-with-quote/") chat_session_id = _create_new_chat_session(run_suffix)
headers = {
"Content-Type": "application/json", url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
}
new_message_request = BasicCreateChatMessageRequest(
chat_session_id=chat_session_id,
message=query,
retrieval_options=retrieval_options,
query_override=query,
)
body = new_message_request.dict() body = new_message_request.dict()
body["user"] = None body["user"] = None
try: try:
response_json = requests.post(url, headers=headers, json=body).json() response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
context_data_list = response_json.get("contexts", {}).get("contexts", []) simple_search_docs = response_json.get("simple_search_docs", [])
answer = response_json.get("answer", "") answer = response_json.get("answer", "")
except Exception as e: except Exception as e:
print("Failed to answer the questions:") print("Failed to answer the questions:")
print(f"\t {str(e)}") print(f"\t {str(e)}")
print("Restarting vespa container and trying again") print("trying again")
restart_vespa_container(run_suffix)
raise e raise e
return context_data_list, answer return simple_search_docs, answer
def check_if_query_ready(run_suffix: str) -> bool: def check_if_query_ready(run_suffix: str) -> bool:
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/") url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
headers = {
"Content-Type": "application/json",
}
indexing_status_dict = requests.get(url, headers=headers).json() indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
ongoing_index_attempts = False ongoing_index_attempts = False
doc_count = 0 doc_count = 0
@@ -94,17 +106,13 @@ def check_if_query_ready(run_suffix: str) -> bool:
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None: def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/") url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
headers = {
"Content-Type": "application/json",
}
body = { body = {
"connector_id": connector_id, "connector_id": connector_id,
"credential_ids": [credential_id], "credential_ids": [credential_id],
"from_beginning": True, "from_beginning": True,
} }
print("body:", body) print("body:", body)
response = requests.post(url, headers=headers, json=body) response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200: if response.status_code == 200:
print("Connector created successfully:", response.json()) print("Connector created successfully:", response.json())
else: else:
@@ -116,13 +124,10 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
url = _api_url_builder( url = _api_url_builder(
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}" run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
) )
headers = {
"Content-Type": "application/json",
}
body = {"name": "zip_folder_contents", "is_public": True} body = {"name": "zip_folder_contents", "is_public": True}
print("body:", body) print("body:", body)
response = requests.put(url, headers=headers, json=body) response = requests.put(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200: if response.status_code == 200:
print("Connector created successfully:", response.json()) print("Connector created successfully:", response.json())
else: else:
@@ -132,14 +137,12 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
def _get_existing_connector_names(run_suffix: str) -> list[str]: def _get_existing_connector_names(run_suffix: str) -> list[str]:
url = _api_url_builder(run_suffix, "/manage/connector") url = _api_url_builder(run_suffix, "/manage/connector")
headers = {
"Content-Type": "application/json",
}
body = { body = {
"credential_json": {}, "credential_json": {},
"admin_public": True, "admin_public": True,
} }
response = requests.get(url, headers=headers, json=body) response = requests.get(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200: if response.status_code == 200:
connectors = response.json() connectors = response.json()
return [connector["name"] for connector in connectors] return [connector["name"] for connector in connectors]
@@ -149,9 +152,6 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
def create_connector(run_suffix: str, file_paths: list[str]) -> int: def create_connector(run_suffix: str, file_paths: list[str]) -> int:
url = _api_url_builder(run_suffix, "/manage/admin/connector") url = _api_url_builder(run_suffix, "/manage/admin/connector")
headers = {
"Content-Type": "application/json",
}
connector_name = base_connector_name = "search_eval_connector" connector_name = base_connector_name = "search_eval_connector"
existing_connector_names = _get_existing_connector_names(run_suffix) existing_connector_names = _get_existing_connector_names(run_suffix)
@@ -172,7 +172,7 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
body = connector.dict() body = connector.dict()
print("body:", body) print("body:", body)
response = requests.post(url, headers=headers, json=body) response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200: if response.status_code == 200:
print("Connector created successfully:", response.json()) print("Connector created successfully:", response.json())
return response.json()["id"] return response.json()["id"]
@@ -182,14 +182,11 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
def create_credential(run_suffix: str) -> int: def create_credential(run_suffix: str) -> int:
url = _api_url_builder(run_suffix, "/manage/credential") url = _api_url_builder(run_suffix, "/manage/credential")
headers = {
"Content-Type": "application/json",
}
body = { body = {
"credential_json": {}, "credential_json": {},
"admin_public": True, "admin_public": True,
} }
response = requests.post(url, headers=headers, json=body) response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200: if response.status_code == 200:
print("credential created successfully:", response.json()) print("credential created successfully:", response.json())
return response.json()["id"] return response.json()["id"]

View File

@@ -1,4 +1,5 @@
import json import json
import multiprocessing
import os import os
import time import time
@@ -13,11 +14,12 @@ RESULTS_FILENAME = "results.jsonl"
METADATA_FILENAME = "metadata.yaml" METADATA_FILENAME = "metadata.yaml"
def _update_results_file(output_folder_path: str, qa_output: dict) -> None: def _populate_results_file(output_folder_path: str, all_qa_output: list[dict]) -> None:
output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME) output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME)
with open(output_file_path, "w", encoding="utf-8") as file: with open(output_file_path, "a", encoding="utf-8") as file:
file.write(json.dumps(qa_output) + "\n") for qa_output in all_qa_output:
file.flush() file.write(json.dumps(qa_output) + "\n")
file.flush()
def _update_metadata_file(test_output_folder: str, count: int) -> None: def _update_metadata_file(test_output_folder: str, count: int) -> None:
@@ -81,8 +83,8 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
del env_vars["ENV_SEED_CONFIGURATION"] del env_vars["ENV_SEED_CONFIGURATION"]
if env_vars["GPG_KEY"]: if env_vars["GPG_KEY"]:
del env_vars["GPG_KEY"] del env_vars["GPG_KEY"]
if metadata["config"]["llm"]["api_key"]: if metadata["test_config"]["llm"]["api_key"]:
del metadata["config"]["llm"]["api_key"] del metadata["test_config"]["llm"]["api_key"]
metadata.update(env_vars) metadata.update(env_vars)
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME) metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
print("saving metadata to:", metadata_path) print("saving metadata to:", metadata_path)
@@ -92,7 +94,34 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
return test_output_folder, questions return test_output_folder, questions
def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
print(f"On question number {question_number}")
query = question_data["question"]
print(f"query: {query}")
context_data_list, answer = get_answer_from_query(
query=query,
run_suffix=config["run_suffix"],
)
if not context_data_list:
print("No answer or context found")
else:
print(f"answer: {answer[:50]}...")
print(f"{len(context_data_list)} context docs found")
print("\n")
output = {
"question_data": question_data,
"answer": answer,
"context_data_list": context_data_list,
}
return output
def _process_and_write_query_results(config: dict) -> None: def _process_and_write_query_results(config: dict) -> None:
start_time = time.time()
test_output_folder, questions = _initialize_files(config) test_output_folder, questions = _initialize_files(config)
print("saving test results to folder:", test_output_folder) print("saving test results to folder:", test_output_folder)
@@ -101,33 +130,26 @@ def _process_and_write_query_results(config: dict) -> None:
if config["limit"] is not None: if config["limit"] is not None:
questions = questions[: config["limit"]] questions = questions[: config["limit"]]
count = 1
for question_data in questions:
print(f"On question number {count}")
query = question_data["question"] with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
print(f"query: {query}") results = pool.starmap(
context_data_list, answer = get_answer_from_query( _process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
query=query,
run_suffix=config["run_suffix"],
) )
if not context_data_list: _populate_results_file(test_output_folder, results)
print("No answer or context found")
else:
print(f"answer: {answer[:50]}...")
print(f"{len(context_data_list)} context docs found")
print("\n")
output = { valid_answer_count = 0
"question_data": question_data, for result in results:
"answer": answer, if result.get("answer"):
"context_data_list": context_data_list, valid_answer_count += 1
}
_update_results_file(test_output_folder, output) _update_metadata_file(test_output_folder, valid_answer_count)
_update_metadata_file(test_output_folder, count)
count += 1 time_to_finish = time.time() - start_time
minutes, seconds = divmod(int(time_to_finish), 60)
print(
f"Took {minutes:02d}:{seconds:02d} to ask and answer {len(results)} questions"
)
def run_qa_test_and_save_results(run_suffix: str = "") -> None: def run_qa_test_and_save_results(run_suffix: str = "") -> None: