From ae2a1d3121124971dc4be52ab5b92b22d7141ceb Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Fri, 12 May 2023 15:47:58 -0700
Subject: [PATCH] DAN-71 Give back all ranked results from search (#34)

---
 backend/danswer/direct_qa/question_answer.py  |  2 +-
 .../semantic_search/semantic_search.py        | 11 ++--
 backend/danswer/server/models.py              |  8 +++
 backend/danswer/server/search_backend.py      | 51 +++++++++++--------
 backend/scripts/simulate_frontend.py          | 17 +++----
 5 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py
index 65ded226c..22ad57859 100644
--- a/backend/danswer/direct_qa/question_answer.py
+++ b/backend/danswer/direct_qa/question_answer.py
@@ -260,7 +260,7 @@ class OpenAICompletionQA(QAModel):
                 if stream_answer_end(model_previous, event_text):
                     found_answer_end = True
                     continue
-                yield {"answer data": event_text}
+                yield {"answer_data": event_text}
 
         except Exception as e:
             logger.exception(e)
diff --git a/backend/danswer/semantic_search/semantic_search.py b/backend/danswer/semantic_search/semantic_search.py
index daaf0ee0e..e175736ec 100644
--- a/backend/danswer/semantic_search/semantic_search.py
+++ b/backend/danswer/semantic_search/semantic_search.py
@@ -19,7 +19,6 @@ import json
 from typing import List
 
 from danswer.chunking.models import InferenceChunk
-from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL
@@ -67,7 +66,6 @@ def warm_up_models() -> None:
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk]:
     cross_encoder = get_default_reranking_model()
     sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
@@ -75,11 +73,9 @@
     scored_results.sort(key=lambda x: x[0], reverse=True)
     ranked_sim_scores, ranked_chunks = zip(*scored_results)
 
-    logger.debug(
-        f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
-    )
+    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")
 
-    return ranked_chunks[:filtered_result_set_size]
+    return ranked_chunks
 
 
 @log_function_time()
@@ -88,7 +84,6 @@ def retrieve_ranked_documents(
     filters: list[DatastoreFilter] | None,
     datastore: Datastore,
     num_hits: int = NUM_RETURNED_HITS,
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk] | None:
     top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
     if not top_chunks:
@@ -97,7 +92,7 @@
             f"Semantic search returned no results with filters: {filters_log_msg}"
         )
         return None
-    ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+    ranked_chunks = semantic_reranking(query, top_chunks)
 
     top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index bf0ce3a3e..704c3a3c4 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -10,6 +10,13 @@ class UserRoleResponse(BaseModel):
     role: str
 
 
+class SearchDoc(BaseModel):
+    semantic_name: str
+    link: str | None
+    blurb: str
+    source_type: str
+
+
 class QAQuestion(BaseModel):
     query: str
     collection: str
@@ -19,6 +26,7 @@ class QAQuestion(BaseModel):
 class QAResponse(BaseModel):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
+    ranked_documents: list[SearchDoc] | None
 
 
 class KeywordResponse(BaseModel):
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index 3bbbf98b6..35678de8c 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -5,6 +5,7 @@ from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_active_user
 from danswer.auth.users import current_admin_user
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
+from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
@@ -16,6 +17,7 @@ from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
 from danswer.server.models import QAResponse
+from danswer.server.models import SearchDoc
 from danswer.server.models import ServerStatus
 from danswer.server.models import UserByEmail
 from danswer.server.models import UserRoleResponse
@@ -81,7 +83,7 @@ async def promote_admin(
 
 
 @router.get("/direct-qa", response_model=QAResponse)
-def direct_qa(question: QAQuestion = Depends()):
+def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
     start_time = time.time()
 
     query = question.query
@@ -93,21 +95,31 @@
         query, filters, create_datastore(collection)
     )
     if not ranked_chunks:
-        return {"answer": None, "quotes": None}
+        return QAResponse(answer=None, quotes=None, ranked_documents=None)
+
+    top_docs = [
+        SearchDoc(
+            semantic_name=chunk.semantic_identifier,
+            link=chunk.source_links.get("0") if chunk.source_links else None,
+            blurb=chunk.blurb,
+            source_type=chunk.source_type,
+        )
+        for chunk in ranked_chunks
+    ]
 
     qa_model = get_default_backend_qa_model()
-    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    answer, quotes = qa_model.answer_question(
+        query, ranked_chunks[:NUM_RERANKED_RESULTS]
+    )
 
     logger.info(f"Total QA took {time.time() - start_time} seconds")
-    return QAResponse(answer=answer, quotes=quotes)
+    return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)
 
 
 @router.get("/stream-direct-qa")
 def stream_direct_qa(question: QAQuestion = Depends()):
     top_documents_key = "top_documents"
-    answer_key = "answer"
-    quotes_key = "quotes"
 
     def stream_qa_portions():
         query = question.query
         collection = question.collection
@@ -120,26 +132,25 @@
         )
         if not ranked_chunks:
             return yield_json_line(
-                {top_documents_key: None, answer_key: None, quotes_key: None}
+                QAResponse(answer=None, quotes=None, ranked_documents=None).dict()
             )
 
-        linked_chunks = [
-            chunk
+        top_docs = [
+            SearchDoc(
+                semantic_name=chunk.semantic_identifier,
+                link=chunk.source_links.get("0") if chunk.source_links else None,
+                blurb=chunk.blurb,
+                source_type=chunk.source_type,
+            )
             for chunk in ranked_chunks
-            if chunk.source_links and "0" in chunk.source_links
         ]
-        top_docs = {
-            top_documents_key: {
-                "document section links": [
-                    chunk.source_links["0"] for chunk in linked_chunks
-                ],
-                "blurbs": [chunk.blurb for chunk in linked_chunks],
-            }
-        }
-        yield yield_json_line(top_docs)
+        top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
+        yield yield_json_line(top_docs_dict)
 
         qa_model = get_default_backend_qa_model()
-        for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+        for response_dict in qa_model.answer_question_stream(
+            query, ranked_chunks[:NUM_RERANKED_RESULTS]
+        ):
             logger.debug(response_dict)
             yield yield_json_line(response_dict)
 
diff --git a/backend/scripts/simulate_frontend.py b/backend/scripts/simulate_frontend.py
index 1afd503a6..98477e98f 100644
--- a/backend/scripts/simulate_frontend.py
+++ b/backend/scripts/simulate_frontend.py
@@ -1,5 +1,6 @@
 import argparse
 import json
+import urllib.parse
 
 import requests
 from danswer.configs.app_configs import APP_PORT
@@ -78,7 +79,9 @@ if __name__ == "__main__":
             "filters": [{SOURCE_TYPE: source_types}],
         }
         if not args.stream:
-            response = requests.get(endpoint, json=query_json)
+            response = requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json)
+            )
             contents = json.loads(response.content)
             if keyword_search:
                 if contents["results"]:
@@ -106,15 +109,11 @@
                 else:
                     print("No quotes found")
         else:
-            answer = ""
-            with requests.get(endpoint, json=query_json, stream=True) as r:
+            with requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json), stream=True
+            ) as r:
                 for json_response in r.iter_lines():
-                    response_dict = json.loads(json_response.decode())
-                    if "answer data" not in response_dict:
-                        print(response_dict)
-                    else:
-                        answer += response_dict["answer data"]
-                print(answer)
+                    print(json.loads(json_response.decode()))
 
     except Exception as e:
         print(f"Failed due to {e}, retrying")
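
Note on the core behavior change: semantic_reranking no longer truncates its output to NUM_RERANKED_RESULTS; the full cross-encoder ranking goes back to the caller so every hit can be displayed, and only the QA call sites slice off the top chunks before prompting the model. A minimal standalone sketch of that pattern, using sentence_transformers.CrossEncoder directly (the repo wraps this behind get_default_reranking_model; the model name and cutoff below are illustrative placeholders, not the project's configured values):

from typing import List

from sentence_transformers import CrossEncoder

# Placeholder cutoff; the real value lives in danswer.configs.app_configs.
NUM_RERANKED_RESULTS = 4


def rerank_all(query: str, passages: List[str], encoder: CrossEncoder) -> List[str]:
    # Score every (query, passage) pair with the cross-encoder, then sort best-first.
    scores = encoder.predict([(query, passage) for passage in passages])
    ranked = sorted(zip(scores, passages), key=lambda pair: pair[0], reverse=True)
    # Return the full ranking, mirroring what semantic_reranking now does.
    return [passage for _, passage in ranked]


if __name__ == "__main__":
    # Assumed model name, for illustration only.
    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    ranking = rerank_all(
        "what does danswer do?",
        ["Danswer answers questions over your docs.", "Unrelated text."],
        encoder,
    )
    all_docs = ranking  # everything ranked goes back to the caller/UI
    llm_context = ranking[:NUM_RERANKED_RESULTS]  # only the top slice reaches the QA model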
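
Note for consumers of /stream-direct-qa: the stream now opens with a "top_documents" record whose entries are SearchDoc payloads serialized via .json() (each one a JSON string that needs a second json.loads), followed by incremental "answer_data" tokens per the question_answer.py key fix above. A rough client sketch, assuming a local server on port 8080 and a collection named "semantic_search" (both placeholders):

import json

import requests

# Assumed local dev URL; simulate_frontend.py builds its endpoint from APP_PORT instead.
ENDPOINT = "http://127.0.0.1:8080/stream-direct-qa"

answer = ""
with requests.get(
    ENDPOINT,
    params={"query": "what does danswer do?", "collection": "semantic_search"},
    stream=True,
) as response:
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # skip keep-alive blanks
        record = json.loads(raw_line.decode())
        if "top_documents" in record:
            # Each entry was serialized with SearchDoc.json(); parse it a second time.
            docs = [json.loads(doc) for doc in record["top_documents"] or []]
            print("Ranked links:", [doc["link"] for doc in docs])
        elif "answer_data" in record:
            answer += record["answer_data"]  # accumulate streamed answer tokens
        else:
            print("Other payload:", record)  # e.g. the final quotes dict
print(answer)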