mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-21 14:12:42 +02:00
Catch dropped eval questions and added multiprocessing (#1849)
This commit is contained in:
@@ -33,6 +33,7 @@ def translate_doc_response_to_simple_doc(
|
|||||||
) -> list[SimpleDoc]:
|
) -> list[SimpleDoc]:
|
||||||
return [
|
return [
|
||||||
SimpleDoc(
|
SimpleDoc(
|
||||||
|
id=doc.document_id,
|
||||||
semantic_identifier=doc.semantic_identifier,
|
semantic_identifier=doc.semantic_identifier,
|
||||||
link=doc.link,
|
link=doc.link,
|
||||||
blurb=doc.blurb,
|
blurb=doc.blurb,
|
||||||
|
@@ -44,6 +44,7 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
|||||||
|
|
||||||
|
|
||||||
class SimpleDoc(BaseModel):
|
class SimpleDoc(BaseModel):
|
||||||
|
id: str
|
||||||
semantic_identifier: str
|
semantic_identifier: str
|
||||||
link: str | None
|
link: str | None
|
||||||
blurb: str
|
blurb: str
|
||||||
|
@@ -2,23 +2,43 @@ import requests
|
|||||||
from retry import retry
|
from retry import retry
|
||||||
|
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
from danswer.configs.constants import MessageType
|
|
||||||
from danswer.connectors.models import InputType
|
from danswer.connectors.models import InputType
|
||||||
from danswer.db.enums import IndexingStatus
|
from danswer.db.enums import IndexingStatus
|
||||||
from danswer.one_shot_answer.models import DirectQARequest
|
|
||||||
from danswer.one_shot_answer.models import ThreadMessage
|
|
||||||
from danswer.search.models import IndexFilters
|
from danswer.search.models import IndexFilters
|
||||||
from danswer.search.models import OptionalSearchSetting
|
from danswer.search.models import OptionalSearchSetting
|
||||||
from danswer.search.models import RetrievalDetails
|
from danswer.search.models import RetrievalDetails
|
||||||
from danswer.server.documents.models import ConnectorBase
|
from danswer.server.documents.models import ConnectorBase
|
||||||
|
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||||
|
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||||
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
|
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
|
||||||
from tests.regression.answer_quality.cli_utils import restart_vespa_container
|
|
||||||
|
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
|
||||||
def _api_url_builder(run_suffix: str, api_path: str) -> str:
|
def _api_url_builder(run_suffix: str, api_path: str) -> str:
|
||||||
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
|
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
|
||||||
|
|
||||||
|
|
||||||
|
def _create_new_chat_session(run_suffix: str) -> int:
|
||||||
|
create_chat_request = ChatSessionCreationRequest(
|
||||||
|
persona_id=0,
|
||||||
|
description=None,
|
||||||
|
)
|
||||||
|
body = create_chat_request.dict()
|
||||||
|
|
||||||
|
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
|
||||||
|
|
||||||
|
response_json = requests.post(
|
||||||
|
create_chat_url, headers=GENERAL_HEADERS, json=body
|
||||||
|
).json()
|
||||||
|
chat_session_id = response_json.get("chat_session_id")
|
||||||
|
|
||||||
|
if isinstance(chat_session_id, int):
|
||||||
|
return chat_session_id
|
||||||
|
else:
|
||||||
|
raise RuntimeError(response_json)
|
||||||
|
|
||||||
|
|
||||||
@retry(tries=15, delay=10, jitter=1)
|
@retry(tries=15, delay=10, jitter=1)
|
||||||
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
||||||
filters = IndexFilters(
|
filters = IndexFilters(
|
||||||
@@ -28,51 +48,43 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
|||||||
tags=None,
|
tags=None,
|
||||||
access_control_list=None,
|
access_control_list=None,
|
||||||
)
|
)
|
||||||
|
retrieval_options = RetrievalDetails(
|
||||||
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
|
run_search=OptionalSearchSetting.ALWAYS,
|
||||||
|
real_time=True,
|
||||||
new_message_request = DirectQARequest(
|
filters=filters,
|
||||||
messages=messages,
|
enable_auto_detect_filters=False,
|
||||||
prompt_id=0,
|
|
||||||
persona_id=0,
|
|
||||||
retrieval_options=RetrievalDetails(
|
|
||||||
run_search=OptionalSearchSetting.ALWAYS,
|
|
||||||
real_time=True,
|
|
||||||
filters=filters,
|
|
||||||
enable_auto_detect_filters=False,
|
|
||||||
),
|
|
||||||
chain_of_thought=False,
|
|
||||||
return_contexts=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
url = _api_url_builder(run_suffix, "/query/answer-with-quote/")
|
chat_session_id = _create_new_chat_session(run_suffix)
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
|
||||||
}
|
|
||||||
|
new_message_request = BasicCreateChatMessageRequest(
|
||||||
|
chat_session_id=chat_session_id,
|
||||||
|
message=query,
|
||||||
|
retrieval_options=retrieval_options,
|
||||||
|
query_override=query,
|
||||||
|
)
|
||||||
|
|
||||||
body = new_message_request.dict()
|
body = new_message_request.dict()
|
||||||
body["user"] = None
|
body["user"] = None
|
||||||
try:
|
try:
|
||||||
response_json = requests.post(url, headers=headers, json=body).json()
|
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
|
||||||
context_data_list = response_json.get("contexts", {}).get("contexts", [])
|
simple_search_docs = response_json.get("simple_search_docs", [])
|
||||||
answer = response_json.get("answer", "")
|
answer = response_json.get("answer", "")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to answer the questions:")
|
print("Failed to answer the questions:")
|
||||||
print(f"\t {str(e)}")
|
print(f"\t {str(e)}")
|
||||||
print("Restarting vespa container and trying agian")
|
print("trying again")
|
||||||
restart_vespa_container(run_suffix)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return context_data_list, answer
|
return simple_search_docs, answer
|
||||||
|
|
||||||
|
|
||||||
def check_if_query_ready(run_suffix: str) -> bool:
|
def check_if_query_ready(run_suffix: str) -> bool:
|
||||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
|
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
indexing_status_dict = requests.get(url, headers=headers).json()
|
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
|
||||||
|
|
||||||
ongoing_index_attempts = False
|
ongoing_index_attempts = False
|
||||||
doc_count = 0
|
doc_count = 0
|
||||||
@@ -94,17 +106,13 @@ def check_if_query_ready(run_suffix: str) -> bool:
|
|||||||
|
|
||||||
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
||||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
|
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"connector_id": connector_id,
|
"connector_id": connector_id,
|
||||||
"credential_ids": [credential_id],
|
"credential_ids": [credential_id],
|
||||||
"from_beginning": True,
|
"from_beginning": True,
|
||||||
}
|
}
|
||||||
print("body:", body)
|
print("body:", body)
|
||||||
response = requests.post(url, headers=headers, json=body)
|
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Connector created successfully:", response.json())
|
print("Connector created successfully:", response.json())
|
||||||
else:
|
else:
|
||||||
@@ -116,13 +124,10 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
|
|||||||
url = _api_url_builder(
|
url = _api_url_builder(
|
||||||
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
||||||
)
|
)
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
body = {"name": "zip_folder_contents", "is_public": True}
|
body = {"name": "zip_folder_contents", "is_public": True}
|
||||||
print("body:", body)
|
print("body:", body)
|
||||||
response = requests.put(url, headers=headers, json=body)
|
response = requests.put(url, headers=GENERAL_HEADERS, json=body)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Connector created successfully:", response.json())
|
print("Connector created successfully:", response.json())
|
||||||
else:
|
else:
|
||||||
@@ -132,14 +137,12 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
|
|||||||
|
|
||||||
def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
||||||
url = _api_url_builder(run_suffix, "/manage/connector")
|
url = _api_url_builder(run_suffix, "/manage/connector")
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
body = {
|
body = {
|
||||||
"credential_json": {},
|
"credential_json": {},
|
||||||
"admin_public": True,
|
"admin_public": True,
|
||||||
}
|
}
|
||||||
response = requests.get(url, headers=headers, json=body)
|
response = requests.get(url, headers=GENERAL_HEADERS, json=body)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
connectors = response.json()
|
connectors = response.json()
|
||||||
return [connector["name"] for connector in connectors]
|
return [connector["name"] for connector in connectors]
|
||||||
@@ -149,9 +152,6 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
|||||||
|
|
||||||
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||||
url = _api_url_builder(run_suffix, "/manage/admin/connector")
|
url = _api_url_builder(run_suffix, "/manage/admin/connector")
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
connector_name = base_connector_name = "search_eval_connector"
|
connector_name = base_connector_name = "search_eval_connector"
|
||||||
existing_connector_names = _get_existing_connector_names(run_suffix)
|
existing_connector_names = _get_existing_connector_names(run_suffix)
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
|||||||
|
|
||||||
body = connector.dict()
|
body = connector.dict()
|
||||||
print("body:", body)
|
print("body:", body)
|
||||||
response = requests.post(url, headers=headers, json=body)
|
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Connector created successfully:", response.json())
|
print("Connector created successfully:", response.json())
|
||||||
return response.json()["id"]
|
return response.json()["id"]
|
||||||
@@ -182,14 +182,11 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
|||||||
|
|
||||||
def create_credential(run_suffix: str) -> int:
|
def create_credential(run_suffix: str) -> int:
|
||||||
url = _api_url_builder(run_suffix, "/manage/credential")
|
url = _api_url_builder(run_suffix, "/manage/credential")
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
body = {
|
body = {
|
||||||
"credential_json": {},
|
"credential_json": {},
|
||||||
"admin_public": True,
|
"admin_public": True,
|
||||||
}
|
}
|
||||||
response = requests.post(url, headers=headers, json=body)
|
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("credential created successfully:", response.json())
|
print("credential created successfully:", response.json())
|
||||||
return response.json()["id"]
|
return response.json()["id"]
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -13,11 +14,12 @@ RESULTS_FILENAME = "results.jsonl"
|
|||||||
METADATA_FILENAME = "metadata.yaml"
|
METADATA_FILENAME = "metadata.yaml"
|
||||||
|
|
||||||
|
|
||||||
def _update_results_file(output_folder_path: str, qa_output: dict) -> None:
|
def _populate_results_file(output_folder_path: str, all_qa_output: list[dict]) -> None:
|
||||||
output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME)
|
output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME)
|
||||||
with open(output_file_path, "w", encoding="utf-8") as file:
|
with open(output_file_path, "a", encoding="utf-8") as file:
|
||||||
file.write(json.dumps(qa_output) + "\n")
|
for qa_output in all_qa_output:
|
||||||
file.flush()
|
file.write(json.dumps(qa_output) + "\n")
|
||||||
|
file.flush()
|
||||||
|
|
||||||
|
|
||||||
def _update_metadata_file(test_output_folder: str, count: int) -> None:
|
def _update_metadata_file(test_output_folder: str, count: int) -> None:
|
||||||
@@ -81,8 +83,8 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
|||||||
del env_vars["ENV_SEED_CONFIGURATION"]
|
del env_vars["ENV_SEED_CONFIGURATION"]
|
||||||
if env_vars["GPG_KEY"]:
|
if env_vars["GPG_KEY"]:
|
||||||
del env_vars["GPG_KEY"]
|
del env_vars["GPG_KEY"]
|
||||||
if metadata["config"]["llm"]["api_key"]:
|
if metadata["test_config"]["llm"]["api_key"]:
|
||||||
del metadata["config"]["llm"]["api_key"]
|
del metadata["test_config"]["llm"]["api_key"]
|
||||||
metadata.update(env_vars)
|
metadata.update(env_vars)
|
||||||
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
|
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
|
||||||
print("saving metadata to:", metadata_path)
|
print("saving metadata to:", metadata_path)
|
||||||
@@ -92,7 +94,34 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
|||||||
return test_output_folder, questions
|
return test_output_folder, questions
|
||||||
|
|
||||||
|
|
||||||
|
def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
|
||||||
|
print(f"On question number {question_number}")
|
||||||
|
|
||||||
|
query = question_data["question"]
|
||||||
|
print(f"query: {query}")
|
||||||
|
context_data_list, answer = get_answer_from_query(
|
||||||
|
query=query,
|
||||||
|
run_suffix=config["run_suffix"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_data_list:
|
||||||
|
print("No answer or context found")
|
||||||
|
else:
|
||||||
|
print(f"answer: {answer[:50]}...")
|
||||||
|
print(f"{len(context_data_list)} context docs found")
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"question_data": question_data,
|
||||||
|
"answer": answer,
|
||||||
|
"context_data_list": context_data_list,
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _process_and_write_query_results(config: dict) -> None:
|
def _process_and_write_query_results(config: dict) -> None:
|
||||||
|
start_time = time.time()
|
||||||
test_output_folder, questions = _initialize_files(config)
|
test_output_folder, questions = _initialize_files(config)
|
||||||
print("saving test results to folder:", test_output_folder)
|
print("saving test results to folder:", test_output_folder)
|
||||||
|
|
||||||
@@ -101,33 +130,26 @@ def _process_and_write_query_results(config: dict) -> None:
|
|||||||
|
|
||||||
if config["limit"] is not None:
|
if config["limit"] is not None:
|
||||||
questions = questions[: config["limit"]]
|
questions = questions[: config["limit"]]
|
||||||
count = 1
|
|
||||||
for question_data in questions:
|
|
||||||
print(f"On question number {count}")
|
|
||||||
|
|
||||||
query = question_data["question"]
|
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
|
||||||
print(f"query: {query}")
|
results = pool.starmap(
|
||||||
context_data_list, answer = get_answer_from_query(
|
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
|
||||||
query=query,
|
|
||||||
run_suffix=config["run_suffix"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not context_data_list:
|
_populate_results_file(test_output_folder, results)
|
||||||
print("No answer or context found")
|
|
||||||
else:
|
|
||||||
print(f"answer: {answer[:50]}...")
|
|
||||||
print(f"{len(context_data_list)} context docs found")
|
|
||||||
print("\n")
|
|
||||||
|
|
||||||
output = {
|
valid_answer_count = 0
|
||||||
"question_data": question_data,
|
for result in results:
|
||||||
"answer": answer,
|
if result.get("answer"):
|
||||||
"context_data_list": context_data_list,
|
valid_answer_count += 1
|
||||||
}
|
|
||||||
|
|
||||||
_update_results_file(test_output_folder, output)
|
_update_metadata_file(test_output_folder, valid_answer_count)
|
||||||
_update_metadata_file(test_output_folder, count)
|
|
||||||
count += 1
|
time_to_finish = time.time() - start_time
|
||||||
|
minutes, seconds = divmod(int(time_to_finish), 60)
|
||||||
|
print(
|
||||||
|
f"Took {minutes:02d}:{seconds:02d} to ask and answer {len(results)} questions"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_qa_test_and_save_results(run_suffix: str = "") -> None:
|
def run_qa_test_and_save_results(run_suffix: str = "") -> None:
|
||||||
|
Reference in New Issue
Block a user