mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 11:58:34 +02:00
Catch dropped eval questions and added multiprocessing (#1849)
This commit is contained in:
parent
25b3dacaba
commit
b83f435bb0
@ -33,6 +33,7 @@ def translate_doc_response_to_simple_doc(
|
||||
) -> list[SimpleDoc]:
|
||||
return [
|
||||
SimpleDoc(
|
||||
id=doc.document_id,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
link=doc.link,
|
||||
blurb=doc.blurb,
|
||||
|
@ -44,6 +44,7 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
blurb: str
|
||||
|
@ -2,23 +2,43 @@ import requests
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from tests.regression.answer_quality.cli_utils import get_api_server_host_port
|
||||
from tests.regression.answer_quality.cli_utils import restart_vespa_container
|
||||
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
def _api_url_builder(run_suffix: str, api_path: str) -> str:
|
||||
return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
|
||||
|
||||
|
||||
def _create_new_chat_session(run_suffix: str) -> int:
|
||||
create_chat_request = ChatSessionCreationRequest(
|
||||
persona_id=0,
|
||||
description=None,
|
||||
)
|
||||
body = create_chat_request.dict()
|
||||
|
||||
create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
|
||||
|
||||
response_json = requests.post(
|
||||
create_chat_url, headers=GENERAL_HEADERS, json=body
|
||||
).json()
|
||||
chat_session_id = response_json.get("chat_session_id")
|
||||
|
||||
if isinstance(chat_session_id, int):
|
||||
return chat_session_id
|
||||
else:
|
||||
raise RuntimeError(response_json)
|
||||
|
||||
|
||||
@retry(tries=15, delay=10, jitter=1)
|
||||
def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
||||
filters = IndexFilters(
|
||||
@ -28,51 +48,43 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
|
||||
tags=None,
|
||||
access_control_list=None,
|
||||
)
|
||||
|
||||
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
|
||||
|
||||
new_message_request = DirectQARequest(
|
||||
messages=messages,
|
||||
prompt_id=0,
|
||||
persona_id=0,
|
||||
retrieval_options=RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=True,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
),
|
||||
chain_of_thought=False,
|
||||
return_contexts=True,
|
||||
retrieval_options = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=True,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
)
|
||||
|
||||
url = _api_url_builder(run_suffix, "/query/answer-with-quote/")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
chat_session_id = _create_new_chat_session(run_suffix)
|
||||
|
||||
url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
|
||||
|
||||
new_message_request = BasicCreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
message=query,
|
||||
retrieval_options=retrieval_options,
|
||||
query_override=query,
|
||||
)
|
||||
|
||||
body = new_message_request.dict()
|
||||
body["user"] = None
|
||||
try:
|
||||
response_json = requests.post(url, headers=headers, json=body).json()
|
||||
context_data_list = response_json.get("contexts", {}).get("contexts", [])
|
||||
response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
|
||||
simple_search_docs = response_json.get("simple_search_docs", [])
|
||||
answer = response_json.get("answer", "")
|
||||
except Exception as e:
|
||||
print("Failed to answer the questions:")
|
||||
print(f"\t {str(e)}")
|
||||
print("Restarting vespa container and trying agian")
|
||||
restart_vespa_container(run_suffix)
|
||||
print("trying again")
|
||||
raise e
|
||||
|
||||
return context_data_list, answer
|
||||
return simple_search_docs, answer
|
||||
|
||||
|
||||
def check_if_query_ready(run_suffix: str) -> bool:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
indexing_status_dict = requests.get(url, headers=headers).json()
|
||||
indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
|
||||
|
||||
ongoing_index_attempts = False
|
||||
doc_count = 0
|
||||
@ -94,17 +106,13 @@ def check_if_query_ready(run_suffix: str) -> bool:
|
||||
|
||||
def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
body = {
|
||||
"connector_id": connector_id,
|
||||
"credential_ids": [credential_id],
|
||||
"from_beginning": True,
|
||||
}
|
||||
print("body:", body)
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
print("Connector created successfully:", response.json())
|
||||
else:
|
||||
@ -116,13 +124,10 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
|
||||
url = _api_url_builder(
|
||||
run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
||||
)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
body = {"name": "zip_folder_contents", "is_public": True}
|
||||
print("body:", body)
|
||||
response = requests.put(url, headers=headers, json=body)
|
||||
response = requests.put(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
print("Connector created successfully:", response.json())
|
||||
else:
|
||||
@ -132,14 +137,12 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
|
||||
|
||||
def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
||||
url = _api_url_builder(run_suffix, "/manage/connector")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
body = {
|
||||
"credential_json": {},
|
||||
"admin_public": True,
|
||||
}
|
||||
response = requests.get(url, headers=headers, json=body)
|
||||
response = requests.get(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
connectors = response.json()
|
||||
return [connector["name"] for connector in connectors]
|
||||
@ -149,9 +152,6 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
|
||||
|
||||
def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||
url = _api_url_builder(run_suffix, "/manage/admin/connector")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
connector_name = base_connector_name = "search_eval_connector"
|
||||
existing_connector_names = _get_existing_connector_names(run_suffix)
|
||||
|
||||
@ -172,7 +172,7 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||
|
||||
body = connector.dict()
|
||||
print("body:", body)
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
print("Connector created successfully:", response.json())
|
||||
return response.json()["id"]
|
||||
@ -182,14 +182,11 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
|
||||
|
||||
def create_credential(run_suffix: str) -> int:
|
||||
url = _api_url_builder(run_suffix, "/manage/credential")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
body = {
|
||||
"credential_json": {},
|
||||
"admin_public": True,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
print("credential created successfully:", response.json())
|
||||
return response.json()["id"]
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
@ -13,11 +14,12 @@ RESULTS_FILENAME = "results.jsonl"
|
||||
METADATA_FILENAME = "metadata.yaml"
|
||||
|
||||
|
||||
def _update_results_file(output_folder_path: str, qa_output: dict) -> None:
|
||||
def _populate_results_file(output_folder_path: str, all_qa_output: list[dict]) -> None:
|
||||
output_file_path = os.path.join(output_folder_path, RESULTS_FILENAME)
|
||||
with open(output_file_path, "w", encoding="utf-8") as file:
|
||||
file.write(json.dumps(qa_output) + "\n")
|
||||
file.flush()
|
||||
with open(output_file_path, "a", encoding="utf-8") as file:
|
||||
for qa_output in all_qa_output:
|
||||
file.write(json.dumps(qa_output) + "\n")
|
||||
file.flush()
|
||||
|
||||
|
||||
def _update_metadata_file(test_output_folder: str, count: int) -> None:
|
||||
@ -81,8 +83,8 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
||||
del env_vars["ENV_SEED_CONFIGURATION"]
|
||||
if env_vars["GPG_KEY"]:
|
||||
del env_vars["GPG_KEY"]
|
||||
if metadata["config"]["llm"]["api_key"]:
|
||||
del metadata["config"]["llm"]["api_key"]
|
||||
if metadata["test_config"]["llm"]["api_key"]:
|
||||
del metadata["test_config"]["llm"]["api_key"]
|
||||
metadata.update(env_vars)
|
||||
metadata_path = os.path.join(test_output_folder, METADATA_FILENAME)
|
||||
print("saving metadata to:", metadata_path)
|
||||
@ -92,7 +94,34 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
|
||||
return test_output_folder, questions
|
||||
|
||||
|
||||
def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
|
||||
print(f"On question number {question_number}")
|
||||
|
||||
query = question_data["question"]
|
||||
print(f"query: {query}")
|
||||
context_data_list, answer = get_answer_from_query(
|
||||
query=query,
|
||||
run_suffix=config["run_suffix"],
|
||||
)
|
||||
|
||||
if not context_data_list:
|
||||
print("No answer or context found")
|
||||
else:
|
||||
print(f"answer: {answer[:50]}...")
|
||||
print(f"{len(context_data_list)} context docs found")
|
||||
print("\n")
|
||||
|
||||
output = {
|
||||
"question_data": question_data,
|
||||
"answer": answer,
|
||||
"context_data_list": context_data_list,
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _process_and_write_query_results(config: dict) -> None:
|
||||
start_time = time.time()
|
||||
test_output_folder, questions = _initialize_files(config)
|
||||
print("saving test results to folder:", test_output_folder)
|
||||
|
||||
@ -101,33 +130,26 @@ def _process_and_write_query_results(config: dict) -> None:
|
||||
|
||||
if config["limit"] is not None:
|
||||
questions = questions[: config["limit"]]
|
||||
count = 1
|
||||
for question_data in questions:
|
||||
print(f"On question number {count}")
|
||||
|
||||
query = question_data["question"]
|
||||
print(f"query: {query}")
|
||||
context_data_list, answer = get_answer_from_query(
|
||||
query=query,
|
||||
run_suffix=config["run_suffix"],
|
||||
with multiprocessing.Pool(processes=multiprocessing.cpu_count() * 2) as pool:
|
||||
results = pool.starmap(
|
||||
_process_question, [(q, config, i + 1) for i, q in enumerate(questions)]
|
||||
)
|
||||
|
||||
if not context_data_list:
|
||||
print("No answer or context found")
|
||||
else:
|
||||
print(f"answer: {answer[:50]}...")
|
||||
print(f"{len(context_data_list)} context docs found")
|
||||
print("\n")
|
||||
_populate_results_file(test_output_folder, results)
|
||||
|
||||
output = {
|
||||
"question_data": question_data,
|
||||
"answer": answer,
|
||||
"context_data_list": context_data_list,
|
||||
}
|
||||
valid_answer_count = 0
|
||||
for result in results:
|
||||
if result.get("answer"):
|
||||
valid_answer_count += 1
|
||||
|
||||
_update_results_file(test_output_folder, output)
|
||||
_update_metadata_file(test_output_folder, count)
|
||||
count += 1
|
||||
_update_metadata_file(test_output_folder, valid_answer_count)
|
||||
|
||||
time_to_finish = time.time() - start_time
|
||||
minutes, seconds = divmod(int(time_to_finish), 60)
|
||||
print(
|
||||
f"Took {minutes:02d}:{seconds:02d} to ask and answer {len(results)} questions"
|
||||
)
|
||||
|
||||
|
||||
def run_qa_test_and_save_results(run_suffix: str = "") -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user