Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-27 20:38:32 +02:00)
Relari Test Script (#1033)
@@ -181,7 +181,7 @@ def process_answer(
         return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
     logger.info(f"All quotes (including unmatched): {quote_strings}")
     quotes = match_quotes_to_docs(quote_strings, chunks)
-    logger.info(f"Final quotes: {quotes}")
+    logger.debug(f"Final quotes: {quotes}")

     return DanswerAnswer(answer=answer), quotes

@@ -24,6 +24,9 @@ from shared_models.model_server_models import IntentResponse
 from shared_models.model_server_models import RerankRequest
 from shared_models.model_server_models import RerankResponse

+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
 logger = setup_logger()
 # Remove useless info about layer initialization
 logging.getLogger("transformers").setLevel(logging.ERROR)
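
The added environment variable silences the HuggingFace tokenizers library's fork-parallelism warning, which otherwise floods logs when a tokenizer is used across forked workers. A minimal standalone sketch of the same pattern, assuming a transformers-based tokenizer (the model name below is a placeholder for illustration, not one the model server necessarily uses):

import os

# Must be set before the tokenizers library spins up its thread pool,
# i.e. before the first tokenizer call in any forked worker.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import AutoTokenizer  # noqa: E402  (imported after env setup)

# Placeholder model purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
print(tokenizer.tokenize("Relari regression test"))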
backend/tests/regression/answer_quality/relari.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+import argparse
+import json
+
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import MessageType
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.one_shot_answer.answer_question import get_search_answer
+from danswer.one_shot_answer.models import DirectQARequest
+from danswer.one_shot_answer.models import OneShotQAResponse
+from danswer.one_shot_answer.models import ThreadMessage
+from danswer.search.models import IndexFilters
+from danswer.search.models import OptionalSearchSetting
+from danswer.search.models import RetrievalDetails
+
+
+def get_answer_for_question(query: str, db_session: Session) -> OneShotQAResponse:
+    filters = IndexFilters(
+        source_type=None,
+        document_set=None,
+        time_cutoff=None,
+        access_control_list=None,
+    )
+
+    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
+
+    new_message_request = DirectQARequest(
+        messages=messages,
+        prompt_id=0,
+        persona_id=0,
+        retrieval_options=RetrievalDetails(
+            run_search=OptionalSearchSetting.ALWAYS,
+            real_time=True,
+            filters=filters,
+            enable_auto_detect_filters=False,
+        ),
+        chain_of_thought=False,
+        return_contexts=True,
+    )
+
+    answer = get_search_answer(
+        query_req=new_message_request,
+        user=None,
+        db_session=db_session,
+        answer_generation_timeout=100,
+        enable_reflexion=False,
+        bypass_acl=True,
+    )
+
+    return answer
+
+
+def read_questions(questions_file_path: str) -> list[dict]:
+    samples = []
+    with open(questions_file_path, "r", encoding="utf-8") as file:
+        for line in file:
+            sample = json.loads(line.strip())
+            samples.append(sample)
+    return samples
+
+
+def main(questions_file: str, output_file: str, limit: int | None = None) -> None:
+    samples = read_questions(questions_file)
+
+    if limit is not None:
+        samples = samples[:limit]
+
+    response_dicts = []
+    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
+        for sample in samples:
+            answer = get_answer_for_question(
+                query=sample["question"], db_session=db_session
+            )
+            assert answer.contexts
+
+            response_dict = {
+                "question": sample["question"],
+                "retrieved_contexts": [
+                    context.content for context in answer.contexts.contexts
+                ],
+                "ground_truth_contexts": sample["ground_truth_contexts"],
+                "answer": answer.answer,
+                "ground_truths": sample["ground_truths"],
+            }
+
+            response_dicts.append(response_dict)
+
+    with open(output_file, "w", encoding="utf-8") as out_file:
+        for response_dict in response_dicts:
+            json_line = json.dumps(response_dict)
+            out_file.write(json_line + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--questions_file",
+        type=str,
+        help="Path to the Relari questions file.",
+        default="./tests/regression/answer_quality/combined_golden_dataset.jsonl",
+    )
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        help="Path to the output results file.",
+        default="./tests/regression/answer_quality/relari_results.txt",
+    )
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=None,
+        help="Limit the number of examples to process.",
+    )
+    args = parser.parse_args()
+
+    main(args.questions_file, args.output_file, args.limit)
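
For context, read_questions() expects the golden dataset as JSONL: one JSON object per line carrying the three keys the script reads ("question", "ground_truth_contexts", "ground_truths"). A minimal sketch of producing such an input file; the sample values are invented for illustration, only the key names come from relari.py:

import json

# Hypothetical rows; only the key names are dictated by relari.py.
samples = [
    {
        "question": "What file formats can be ingested?",
        "ground_truth_contexts": ["Supported formats include PDF, ..."],
        "ground_truths": ["PDF and plain text, among others."],
    },
]

with open("combined_golden_dataset.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")

The default paths are relative (./tests/regression/answer_quality/...), so the script is presumably invoked from the backend directory, e.g. `python tests/regression/answer_quality/relari.py --limit 10`; it writes its results in the same one-JSON-object-per-line format.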